Beispiel #1
0
def _do_work(jdata, run_opt):
    """Initialise, build and train an NNP model from a parsed input dict.

    Parameters
    ----------
    jdata : dict
        Parsed input. ``jdata['training']`` must provide ``systems``,
        ``set_prefix``, ``batch_size``, ``numb_test`` and ``stop_batch``;
        ``seed``, ``sys_probs`` and ``auto_prob_style`` are optional.
        ``jdata['model']`` may contain a ``modifier`` section.
    run_opt
        Run options; passed to ``NNPTrainer`` and ``data.print_summary`` and
        used (via ``run_opt.message``) to report the training wall time.

    Raises
    ------
    RuntimeError
        If ``jdata`` has no ``'training'`` section, or the modifier type is
        not ``'dipole_charge'``.
    """
    # init the model
    model = NNPTrainer(jdata, run_opt=run_opt)
    rcut = model.model.get_rcut()
    type_map = model.model.get_type_map()
    # init params and run options
    # explicit check: an `assert` would silently disappear under `python -O`
    if 'training' not in jdata:
        raise RuntimeError("missing 'training' section in the input")
    training = jdata['training']
    systems = j_must_have(training, 'systems')
    if isinstance(systems, str):
        systems = expand_sys_str(systems)
    set_pfx = j_must_have(training, 'set_prefix')
    seed = training.get('seed')
    if seed is not None:
        # np.random.seed only accepts seeds in [0, 2**32)
        seed = seed % (2**32)
    np.random.seed(seed)
    batch_size = j_must_have(training, 'batch_size')
    test_size = j_must_have(training, 'numb_test')
    stop_batch = j_must_have(training, 'stop_batch')
    sys_probs = training.get('sys_probs')
    auto_prob_style = training.get('auto_prob_style', 'prob_sys_size')
    # an empty type_map is forwarded to the data system as None
    ipt_type_map = None if len(type_map) == 0 else type_map
    # data modifier
    modifier = None
    modi_data = jdata['model'].get("modifier", None)
    if modi_data is not None:
        if modi_data['type'] == 'dipole_charge':
            modifier = DipoleChargeModifier(modi_data['model_name'],
                                            modi_data['model_charge_map'],
                                            modi_data['sys_charge_map'],
                                            modi_data['ewald_h'],
                                            modi_data['ewald_beta'])
        else:
            raise RuntimeError('unknown modifier type ' +
                               str(modi_data['type']))
    # init data
    data = DeepmdDataSystem(systems,
                            batch_size,
                            test_size,
                            rcut,
                            set_prefix=set_pfx,
                            type_map=ipt_type_map,
                            modifier=modifier)
    data.print_summary(run_opt,
                       sys_probs=sys_probs,
                       auto_prob_style=auto_prob_style)
    data.add_dict(data_requirement)
    # build the model with stats from the first system
    model.build(data, stop_batch)
    # train the model with the provided systems in a cyclic way
    # perf_counter is monotonic, so the reported duration is not affected
    # by system clock adjustments during a long training run
    start_time = time.perf_counter()
    model.train(data)
    elapsed = time.perf_counter() - start_time
    run_opt.message("finished training\nwall time: %.3f s" % elapsed)
Beispiel #2
0
def test(args):
    """Run ``dp test`` for every system matched by ``args.system``.

    Parameters
    ----------
    args
        Parsed command-line arguments. ``args.model`` is the model path and
        ``args.system`` the system pattern expanded with ``expand_sys_str``.
        ``args`` is also forwarded to the per-model-type test helpers, which
        presumably read further options from it (TODO confirm). Note that
        ``args.system`` is overwritten with each expanded system path in turn.

    Raises
    ------
    RuntimeError
        If the model type is not one of ``ener``, ``dipole``, ``polar``,
        ``global_polar`` or ``wfc``.
    """
    de = DeepEval(args.model)
    all_sys = expand_sys_str(args.system)
    if len(all_sys) == 0:
        print('Did not find valid system')
    model_type = de.model_type
    if model_type == 'ener':
        dp = DeepPot(args.model)
    elif model_type == 'dipole':
        dp = DeepDipole(args.model)
    elif model_type == 'polar':
        dp = DeepPolar(args.model)
    elif model_type == 'global_polar':
        dp = DeepGlobalPolar(args.model)
    elif model_type == 'wfc':
        dp = DeepWFC(args.model)
    else:
        raise RuntimeError('unknow model type ' + model_type)
    if len(all_sys) == 0:
        # nothing to test, so do not average over empty error collections
        return
    err_coll = []
    siz_coll = []
    # model_type has been validated above, so every branch below is reachable
    for cc, ii in enumerate(all_sys):
        # the test helpers read the current system from args
        args.system = ii
        print("# ---------------output of dp test--------------- ")
        print("# testing system : " + ii)
        if model_type == 'ener':
            # only the first system starts a fresh detail file
            err, siz = test_ener(dp, args, append_detail=(cc != 0))
        elif model_type == 'dipole':
            err, siz = test_dipole(dp, args)
        elif model_type == 'polar':
            err, siz = test_polar(dp, args, global_polar=False)
        elif model_type == 'global_polar':
            err, siz = test_polar(dp, args, global_polar=True)
        else:  # 'wfc'
            err, siz = test_wfc(dp, args)
        print("# ----------------------------------------------- ")
        err_coll.append(err)
        siz_coll.append(siz)
    if len(all_sys) != len(err_coll):
        print('Not all systems are tested! Check if the systems are valid')
    if len(all_sys) > 1:
        avg_err = weighted_average(err_coll, siz_coll)
        print("# ----------weighted average of errors----------- ")
        print("# number of systems : %d" % len(all_sys))
        if model_type == 'ener':
            print_ener_sys_avg(avg_err)
        elif model_type == 'dipole':
            print_dipole_sys_avg(avg_err)
        elif model_type in ('polar', 'global_polar'):
            print_polar_sys_avg(avg_err)
        else:  # 'wfc'
            print_wfc_sys_avg(avg_err)
        print("# ----------------------------------------------- ")
Beispiel #3
0
def get_data(jdata: Dict[str, Any], rcut, type_map, modifier):
    """Build a ``DeepmdDataSystem`` from the ``systems`` entry of ``jdata``.

    Parameters
    ----------
    jdata : Dict[str, Any]
        Data section of the input. Must provide ``systems`` (a list of paths,
        or a single string expanded with ``expand_sys_str``) and
        ``batch_size``; ``sys_probs`` and ``auto_prob`` are optional.
    rcut
        Cutoff radius, forwarded to ``DeepmdDataSystem``.
    type_map
        Type map, forwarded to ``DeepmdDataSystem``.
    modifier
        Data modifier, forwarded to ``DeepmdDataSystem``.

    Returns
    -------
    DeepmdDataSystem
        The data system with ``data_requirement`` registered via ``add_dict``.

    Raises
    ------
    IOError
        If no system is found, or a system path is not a directory or has no
        ``type.raw`` file.
    """
    systems = j_must_have(jdata, "systems")
    if isinstance(systems, str):
        systems = expand_sys_str(systems)
    help_msg = 'Please check your setting for data systems'

    def _fail(msg):
        # Log the reason, then raise with a single readable message.
        # (OSError with two positional args treats them as (errno, strerror),
        # which rendered as "[Errno <msg>] <help_msg>".)
        log.fatal(msg)
        raise IOError(f'{msg}. {help_msg}')

    # check length of systems
    if not systems:
        _fail('cannot find a valid data system')
    # roughly check that every item in systems looks like a data system dir
    for sys_path in map(DPPath, systems):
        if not sys_path.is_dir():
            _fail(f'dir {sys_path} is not a valid dir')
        if not (sys_path / 'type.raw').is_file():
            _fail(f'dir {sys_path} is not a valid data system dir')

    batch_size = j_must_have(jdata, "batch_size")
    sys_probs = jdata.get("sys_probs", None)
    auto_prob = jdata.get("auto_prob", "prob_sys_size")

    data = DeepmdDataSystem(
        systems=systems,
        batch_size=batch_size,
        test_size=1,        # to satisfy the old api
        shuffle_test=True,  # to satisfy the old api
        rcut=rcut,
        type_map=type_map,
        modifier=modifier,
        trn_all_set=True,    # sample from all sets
        sys_probs=sys_probs,
        auto_prob_style=auto_prob
    )
    data.add_dict(data_requirement)

    return data
Beispiel #4
0
def neighbor_stat(
    *,
    system: str,
    rcut: float,
    type_map: List[str],
    **kwargs,
):
    """Calculate neighbor statistics.

    Parameters
    ----------
    system : str
        system to stat; expanded with ``expand_sys_str``
    rcut : float
        cutoff radius
    type_map : list[str]
        type map
    **kwargs
        accepted and ignored

    Returns
    -------
    tuple
        ``(min_nbor_dist, max_nbor_size)`` as returned by
        ``NeighborStat.get_stat``; both values are also logged at INFO level.

    Raises
    ------
    RuntimeError
        if no valid system is found under ``system``

    Examples
    --------
    >>> neighbor_stat(system='.', rcut=6., type_map=["C", "H", "O", "N", "P", "S", "Mg", "Na", "HW", "OW", "mNa", "mCl", "mC", "mH", "mMg", "mN", "mO", "mP"])  # doctest: +SKIP

    which logs (values depend on the data; the exact formatting of
    ``max_nbor_size`` depends on the type ``get_stat`` returns)::

        min_nbor_dist: 0.659951
        max_nbor_size: [23, 26, 19, 16, 2, 2, 1, 1, 72, 37, 5, 0, 31, 29, 1, 21, 20, 5]
    """
    all_sys = expand_sys_str(system)
    if not all_sys:
        raise RuntimeError("Did not find valid system")
    data = DeepmdDataSystem(
        systems=all_sys,
        batch_size=1,
        test_size=1,
        rcut=rcut,
        type_map=type_map,
    )
    # NOTE(review): the fetched batch is discarded; presumably this call is
    # needed to initialise the data system before get_stat — confirm.
    data.get_batch()
    nei = NeighborStat(data.get_ntypes(), rcut)
    min_nbor_dist, max_nbor_size = nei.get_stat(data)
    # lazy logger arguments: formatting happens only if INFO is enabled
    log.info("min_nbor_dist: %f", min_nbor_dist)
    log.info("max_nbor_size: %s", max_nbor_size)
    return min_nbor_dist, max_nbor_size
Beispiel #5
0
 def test_expand(self):
     """Check that the sorted output of expand_sys_str('test_sys') equals self.expected_out."""
     found = sorted(expand_sys_str('test_sys'))
     self.assertEqual(found, self.expected_out)
Beispiel #6
0
def make_model_devi(*, models: list, system: str, set_prefix: str, output: str,
                    frequency: int, **kwargs):
    '''
    Make model deviation calculation

    Parameters
    ----------
    models: list
        A list of paths of models to use for making model deviation
    system: str
        The path of system to make model deviation calculation;
        expanded with ``expand_sys_str``
    set_prefix: str
        The set prefix of the system
    output: str
        The output file for model deviation results
    frequency: int
        The number of steps that elapse between writing coordinates
        in a trajectory by a MD engine (such as Gromacs / Lammps).
        This parameter is used to determine the index in the output file.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    list of np.ndarray
        One deviation array per expanded system. Column 0 is overwritten
        with the step index, i.e. frame index times ``frequency``.

    Raises
    ------
    RuntimeError
        If the models do not share the same type map, or no valid system
        is found.
    '''
    # one AutoBatchSize instance is shared by all models
    auto_batch_size = AutoBatchSize()
    # init models
    dp_models = [
        DeepPot(model, auto_batch_size=auto_batch_size) for model in models
    ]

    # check type maps
    tmaps = [dp.get_type_map() for dp in dp_models]
    if not _check_tmaps(tmaps):
        raise RuntimeError("The models does not have the same type map.")
    tmap = tmaps[0]

    all_sys = expand_sys_str(system)
    if len(all_sys) == 0:
        raise RuntimeError("Did not find valid system")
    devis_coll = []
    # use a distinct loop name so the `system` argument is not shadowed
    for sys_path in all_sys:
        # create data-system
        dp_data = DeepmdData(sys_path,
                             set_prefix,
                             shuffle_test=False,
                             type_map=tmap)
        nopbc = not dp_data.pbc

        nframes_tot = 0
        devis = []
        # load and process one set at a time instead of holding every set
        # of the system in memory at once
        for set_name in dp_data.dirs:
            data = dp_data._load_set(set_name)
            coord = data["coord"]
            box = data["box"]
            # atom types are taken from the first frame of the set
            atype = data["type"][0]
            devi = calc_model_devi(coord, box, atype, dp_models, nopbc=nopbc)
            nframes_tot += coord.shape[0]
            devis.append(devi)
        devis = np.vstack(devis)
        devis[:, 0] = np.arange(nframes_tot) * frequency
        # NOTE(review): called once per system with the same `output` path;
        # whether later systems append or overwrite depends on
        # write_model_devi_out — confirm.
        write_model_devi_out(devis, output)
        devis_coll.append(devis)
    return devis_coll
Beispiel #7
0
def test(
    *,
    model: str,
    system: str,
    set_prefix: str,
    numb_test: int,
    rand_seed: Optional[int],
    shuffle_test: bool,
    detail_file: Optional[str],
    atomic: bool,
    **kwargs,
):
    """Test model predictions.

    Parameters
    ----------
    model : str
        path where model is stored
    system : str
        system directory; expanded with ``expand_sys_str``
    set_prefix : str
        string prefix of set
    numb_test : int
        number of tests to do
    rand_seed : Optional[int]
        seed for random generator
    shuffle_test : bool
        whether to shuffle tests
    detail_file : Optional[str]
        file where test details will be output
    atomic : bool
        whether per atom quantities should be computed
    **kwargs
        accepted and ignored

    Raises
    ------
    RuntimeError
        if no valid system was found, or the model type is not one of
        ``ener``, ``dipole``, ``polar`` or ``global_polar``
    """
    all_sys = expand_sys_str(system)
    if len(all_sys) == 0:
        raise RuntimeError("Did not find valid system")
    err_coll = []

    # init random seed
    if rand_seed is not None:
        dp_random.seed(rand_seed % (2 ** 32))

    # init model
    dp = DeepPotential(model)
    model_type = dp.model_type
    # Fail fast: for any other model type no branch in the loop below
    # assigns `err`, which previously crashed with UnboundLocalError.
    if model_type not in ("ener", "dipole", "polar", "global_polar"):
        raise RuntimeError(f"unknown model type {model_type}")

    # the type map is forwarded to the data loader only for energy models
    tmap = dp.get_type_map() if model_type == "ener" else None

    # use a distinct loop name so the `system` argument is not shadowed
    for cc, sys_path in enumerate(all_sys):
        log.info("# ---------------output of dp test--------------- ")
        log.info(f"# testing system : {sys_path}")

        # create data class
        data = DeepmdData(sys_path, set_prefix, shuffle_test=shuffle_test, type_map=tmap)

        if model_type == "ener":
            err = test_ener(
                dp,
                data,
                sys_path,
                numb_test,
                detail_file,
                atomic,
                # only the first system starts a fresh detail file
                append_detail=(cc != 0),
            )
        elif model_type == "dipole":
            err = test_dipole(dp, data, numb_test, detail_file, atomic)
        elif model_type == "polar":
            err = test_polar(dp, data, numb_test, detail_file, atomic=atomic)
        else:  # "global_polar": should not appear in this new version
            log.warning("Global polar model is not currently supported. Please directly use the polar mode and change loss parameters.")
            err = test_polar(dp, data, numb_test, detail_file, atomic=False)    # YWolfeee: downward compatibility
        log.info("# ----------------------------------------------- ")
        err_coll.append(err)

    avg_err = weighted_average(err_coll)

    if len(all_sys) != len(err_coll):
        log.warning("Not all systems are tested! Check if the systems are valid")

    if len(all_sys) > 1:
        log.info("# ----------weighted average of errors----------- ")
        log.info(f"# number of systems : {len(all_sys)}")
        if model_type == "ener":
            print_ener_sys_avg(avg_err)
        elif model_type == "dipole":
            print_dipole_sys_avg(avg_err)
        else:  # "polar" or "global_polar"
            print_polar_sys_avg(avg_err)
        log.info("# ----------------------------------------------- ")