Example #1
0
def main(train_paths,
         test_paths=None,
         restart=None,
         ckpt_file=None,
         model_args=None,
         data_args=None,
         preprocess_args=None,
         train_args=None,
         proj_basis=None,
         seed=None,
         device=None):
    """Build the data readers and the model, then preprocess and train.

    Args:
        train_paths: Passed through `load_dirs`; the result feeds the
            training `GroupReader`.
        test_paths: Optional. Handled like `train_paths` to build a second
            `GroupReader`. When None, `test_reader=None` is given to `train`.
        restart: When not None, the model is obtained from
            `CorrNet.load(restart)` and `model_args` is not used to build it.
        ckpt_file: When not None, overrides `train_args["ckpt_file"]`.
        model_args: Keyword arguments for `CorrNet(...)`. Its `input_dim`
            entry is always replaced by the reader's `ndesc`.
        data_args: Keyword arguments for every `GroupReader`.
        preprocess_args: Keyword arguments for `preprocess`.
        train_args: Keyword arguments for `train`.
        proj_basis: When not None, overrides `model_args["proj_basis"]`.
        seed: Seed for both NumPy and torch. Drawn at random when None.
        device: When not None, overrides `train_args["device"]`.
    """
    if seed is None:
        # Request uint32 explicitly: with the platform default integer type
        # an upper bound of 2**32 is out of range where that type is 32-bit.
        seed = int(np.random.randint(0, 2**32, dtype=np.uint32))
    print(f'# using seed: {seed}')
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Work on shallow copies so the overrides below never leak into the
    # dictionaries owned by the caller.
    model_args = {} if model_args is None else dict(model_args)
    data_args = {} if data_args is None else dict(data_args)
    preprocess_args = {} if preprocess_args is None else dict(preprocess_args)
    train_args = {} if train_args is None else dict(train_args)
    if proj_basis is not None:
        model_args["proj_basis"] = proj_basis
    if ckpt_file is not None:
        train_args["ckpt_file"] = ckpt_file
    if device is not None:
        train_args["device"] = device

    train_paths = load_dirs(train_paths)
    g_reader = GroupReader(train_paths, **data_args)
    if test_paths is not None:
        test_paths = load_dirs(test_paths)
        test_reader = GroupReader(test_paths, **data_args)
    else:
        print('# testing with training set')
        test_reader = None

    if restart is not None:
        model = CorrNet.load(restart)
    else:
        input_dim = g_reader.ndesc
        if model_args.get("input_dim", input_dim) != input_dim:
            # The piece containing the placeholder must itself be an
            # f-string, otherwise "{input_dim}" is printed verbatim.
            print("# `input_dim` in `model_args` does not match data",
                  f"({input_dim}).",
                  "Use the one in data.",
                  file=sys.stderr)
        model_args["input_dim"] = input_dim
        model = CorrNet(**model_args).double()

    preprocess(model, g_reader, **preprocess_args)
    train(model, g_reader, test_reader=test_reader, **train_args)
Example #2
0
def main(data_paths,
         model_file="model.pth",
         output_prefix='test',
         group=False,
         e_name='l_e_delta',
         d_name=['dm_eig']):
    """Run `test` once per model file on a reader built from `data_paths`.

    Args:
        data_paths: Passed through `load_dirs` before building the
            `GroupReader`.
        model_file: One model file or several; normalised by `check_list`.
        output_prefix: Joined onto the directory of each model file to form
            the `dump_prefix` handed to `test`.
        group: Forwarded unchanged to `test`.
        e_name: Forwarded to `GroupReader` as `e_name`.
        d_name: Forwarded to `GroupReader` as `d_name`.
    """
    reader = GroupReader(load_dirs(data_paths), e_name=e_name, d_name=d_name)
    for model_path in check_list(model_file):
        print(model_path)
        model_dir = os.path.dirname(model_path)
        net = CorrNet.load(model_path).double().to(DEVICE)
        dump_prefix = os.path.join(model_dir, output_prefix)
        # The prefix may itself contain sub-directories; create them up
        # front. An empty parent means "current directory": nothing to do.
        parent = os.path.dirname(dump_prefix)
        if parent:
            os.makedirs(parent, exist_ok=True)
        test(net, reader, dump_prefix=dump_prefix, group=group)
Example #3
0
 def __init__(self, model, proj_basis=None, device=DEVICE):
     """Attach a network and its projection basis to this object.

     Call this after the base SCF class has been initialised; otherwise
     an error is raised because the `mol` attribute is still missing.

     Args:
         model: A string (loaded through `CorrNet.load` and cast to
             double), a `torch.nn.Module`, or any other object, which is
             stored as is.
         proj_basis: Projection basis given to `load_basis`. When None,
             the model's `_pbas` attribute is used if it has one.
         device: Device that a module-type model is moved to.
     """
     self.device = device
     net = CorrNet.load(model).double() if isinstance(model, str) else model
     if isinstance(net, torch.nn.Module):
         # only real modules are moved and switched to evaluation mode
         net = net.to(self.device).eval()
     self.net = net
     # fall back to the basis carried by the model, if there is any
     basis = proj_basis if proj_basis is not None \
         else getattr(net, "_pbas", None)
     # expected to be a list afterwards, following the pyscf convention
     self._pbas = load_basis(basis)
     # shell sizes, laid out like [1,1,1,...,3,3,3,...,5,5,5,...]
     self._shell_sec = get_shell_sec(self._pbas)
     # total number of projected basis functions per atom
     self.nproj = sum(self._shell_sec)
     # overlap integrals needed for the projection
     self.prepare_integrals()
Example #4
0
 def __init__(self, model, proj_basis=None, device=DEVICE):
     """Attach a network and its projection basis to this object.

     Call this after the base SCF class has been initialised; otherwise
     an error is raised because the `mol` attribute is still missing.

     Args:
         model: A string (loaded through `CorrNet.load` and cast to
             double), a `torch.nn.Module`, or any other object, which is
             stored as is.
         proj_basis: Projection basis given to `load_basis`. When None
             and `model` was given as a string, the basis is looked up
             under `extra_info` -> `proj_basis` in the object returned by
             `torch.load` for that same string.
         device: Device that a module-type model is moved to.
     """
     source = model  # keep the raw argument: it may be a file path
     self.device = device
     net = CorrNet.load(model).double() if isinstance(model, str) else model
     if isinstance(net, torch.nn.Module):
         net = net.to(self.device)
     self.net = net
     # try to recover the basis from the model file itself
     if proj_basis is None and isinstance(source, str):
         saved = torch.load(source, map_location="cpu")
         extra_info = saved.get("extra_info", {})
         proj_basis = extra_info.get("proj_basis", None)
     # expected to be a list afterwards, following the pyscf convention
     self._pbas = load_basis(proj_basis)
     # every entry b contributes (len(b) - 1) copies of 2 * b[0] + 1,
     # giving [1,1,1,...,3,3,3,...,5,5,5,...]
     shell_sec = []
     for shell in self._pbas:
         shell_sec.extend([2 * shell[0] + 1] * (len(shell) - 1))
     self._shell_sec = shell_sec
     # total number of projected basis functions per atom
     self.nproj = sum(self._shell_sec)
     # overlap integrals needed for the projection
     self.prepare_integrals()
Example #5
0
def main(systems,
         model_file="model.pth",
         basis='ccpvdz',
         proj_basis=None,
         penalty_terms=None,
         device=None,
         dump_dir=".",
         dump_fields=DEFAULT_FNAMES,
         group=False,
         mol_args=None,
         scf_args=None,
         verbose=0):
    """Solve every frame of every system and dump the selected fields.

    Args:
        systems: Passed through `load_sys_paths` to get the system list.
        model_file: Model file given to `CorrNet.load`. None or the string
            "NONE" (any case) means no model; `DEFAULT_HF_ARGS` then
            replaces `DEFAULT_SCF_ARGS` as the base of the SCF arguments.
        basis: Forwarded to `build_mol` as `basis`.
        proj_basis: Forwarded to `solve_mol`.
        penalty_terms: Normalised by `check_list`; each entry is turned
            into a penalty with `build_penalty` for every frame.
        device: Forwarded to `solve_mol`.
        dump_dir: Output directory. Without `group`, one sub-directory per
            system is used.
        dump_fields: Passed through `select_fields`.
        group: When true, results of all systems are collected and dumped
            once at the end. Processing stops at the first system whose
            meta differs from the previous one.
        mol_args: Extra keyword arguments for `build_mol`.
        scf_args: Overrides applied on top of the default SCF arguments.
        verbose: Verbosity level; also forwarded to `build_mol` and
            `solve_mol`.

    Raises:
        Exception: Whatever `solve_mol` raises is reported on stderr and
            then re-raised.
    """
    if model_file is None or model_file.upper() == "NONE":
        model = None
        default_scf_args = DEFAULT_HF_ARGS
    else:
        model = CorrNet.load(model_file).double()
        default_scf_args = DEFAULT_SCF_ARGS

    # check arguments
    penalty_terms = check_list(penalty_terms)
    if mol_args is None:
        mol_args = {}
    if scf_args is None:
        scf_args = {}
    scf_args = {**default_scf_args, **scf_args}
    fields = select_fields(dump_fields)
    # check label names from label fields and penalties
    label_names = get_required_labels(fields["scf"] + fields["grad"],
                                      penalty_terms)

    if verbose:
        print(f"starting calculation with OMP threads: {lib.num_threads()}",
              f"and max memory: {lib.param.MAX_MEMORY}")
        if verbose > 1:
            print(f"basis: {basis}")
            print(f"specified scf args:\n  {scf_args}")

    meta = old_meta = None
    res_list = []
    systems = load_sys_paths(systems)

    for fl in systems:
        fl = fl.rstrip(os.path.sep)
        for atom, attrs, labels in system_iter(fl, label_names):
            # later entries win: per-frame attrs override atom/basis/verbose,
            # which in turn override the user supplied mol_args
            mol_input = {
                **mol_args, "verbose": verbose,
                "atom": atom,
                "basis": basis,
                **attrs
            }
            mol = build_mol(**mol_input)
            penalties = [build_penalty(pd, labels) for pd in penalty_terms]
            try:
                meta, result = solve_mol(mol,
                                         model,
                                         fields,
                                         labels,
                                         proj_basis=proj_basis,
                                         penalties=penalties,
                                         device=device,
                                         verbose=verbose,
                                         **scf_args)
            except Exception as e:
                print(fl, 'failed! error:', e, file=sys.stderr)
                # re-raise on purpose: a failing system aborts the run
                raise
            if group and old_meta is not None and np.any(meta != old_meta):
                # do not mix this frame into the grouped results
                break
            res_list.append(result)

        if not group:
            sub_dir = os.path.join(dump_dir,
                                   get_sys_name(os.path.basename(fl)))
            dump_meta(sub_dir, meta)
            dump_data(sub_dir, **collect_fields(fields, meta, res_list))
            res_list = []
        elif old_meta is not None and np.any(meta != old_meta):
            print(fl,
                  'meta does not match! saving previous results only.',
                  file=sys.stderr)
            # res_list only holds results obtained under old_meta; restore
            # it so the grouped dump below does not pair those results with
            # the mismatching meta of the rejected system
            meta = old_meta
            break
        old_meta = meta
        if verbose:
            print(fl, 'finished')

    if group:
        dump_meta(dump_dir, meta)
        dump_data(dump_dir, **collect_fields(fields, meta, res_list))
        if verbose:
            print('group finished')