Example #1
def test_qm9(qm9_path, qm9_dataset):
    """
    Test if QM9 dataset object has same behaviour as AtomsData.

    """
    atoms_data = spk.AtomsData(qm9_path)
    assert_dataset_equal(atoms_data, qm9_dataset)
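The `assert_dataset_equal` helper used in this and the later test examples is not shown in the excerpts. Below is a minimal sketch of what such a helper might check; the comparison logic and the use of `numpy.allclose` are assumptions for illustration, not the actual SchNetPack test utility.

import numpy as np

def assert_dataset_equal(dataset_a, dataset_b):
    # Hypothetical helper: compare dataset lengths and the per-entry
    # property dictionaries returned by AtomsData-style indexing.
    assert len(dataset_a) == len(dataset_b)
    for idx in range(len(dataset_a)):
        props_a, props_b = dataset_a[idx], dataset_b[idx]
        assert set(props_a.keys()) == set(props_b.keys())
        for key in props_a:
            assert np.allclose(props_a[key], props_b[key])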
Example #2
def get_dataset(args, train_args, environment_provider):
    load_only = [train_args.property]
    # older training argument sets may not define a derivative attribute
    try:
        if train_args.derivative is not None:
            load_only.append(train_args.derivative)
    except AttributeError:
        pass

    dataset = spk.AtomsData(
        args.dbpath,
        load_only=load_only,
        collect_triples=train_args.model == "wacsf",
        environment_provider=environment_provider,
    )
    return dataset
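For context, a hedged sketch of how this function might be invoked: the attribute names on the Namespace objects (dbpath, property, derivative, model) mirror the ones read inside get_dataset, and the database path is a placeholder.

from argparse import Namespace

import schnetpack as spk

args = Namespace(dbpath="data/ethanol.db")  # placeholder path
train_args = Namespace(property="energy", derivative="forces", model="schnet")
env_provider = spk.environment.SimpleEnvironmentProvider()

dataset = get_dataset(args, train_args, environment_provider=env_provider)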
Example #3
def get_dataset(args, environment_provider, logging=None):
    """
    Get dataset from arguments.

    Args:
        args (argparse.Namespace): parsed arguments
        environment_provider (spk.environment.BaseEnvironmentProvider):
            environment provider of the dataset
        logging: logger

    Returns:
        spk.data.AtomsData: dataset

    """
    if args.dataset == "qm9":
        if logging:
            logging.info("QM9 will be loaded...")
        qm9 = spk.datasets.QM9(
            args.datapath,
            download=True,
            load_only=[args.property],
            collect_triples=args.model == "wacsf",
            remove_uncharacterized=args.remove_uncharacterized,
            environment_provider=environment_provider,
        )
        return qm9
    elif args.dataset == "ani1":
        if logging:
            logging.info("ANI1 will be loaded...")
        ani1 = spk.datasets.ANI1(
            args.datapath,
            download=True,
            load_only=[args.property],
            collect_triples=args.model == "wacsf",
            num_heavy_atoms=args.num_heavy_atoms,
            environment_provider=environment_provider,
        )
        return ani1
    elif args.dataset == "md17":
        if logging:
            logging.info("MD17 will be loaded...")
        md17 = spk.datasets.MD17(
            args.datapath,
            args.molecule,
            download=True,
            collect_triples=args.model == "wacsf",
            environment_provider=environment_provider,
        )
        return md17
    elif args.dataset == "matproj":
        if logging:
            logging.info("Materials project will be loaded...")
        mp = spk.datasets.MaterialsProject(
            args.datapath,
            apikey=args.apikey,
            download=True,
            load_only=[args.property],
            environment_provider=environment_provider,
        )
        if args.timestamp:
            mp = mp.at_timestamp(args.timestamp)
        return mp
    elif args.dataset == "omdb":
        if logging:
            logging.info("Organic Materials Database will be loaded...")
        omdb = spk.datasets.OrganicMaterialsDatabase(
            args.datapath,
            download=True,
            load_only=[args.property],
            environment_provider=environment_provider,
        )
        return omdb
    elif args.dataset == "custom":
        if logging:
            logging.info("Custom dataset will be loaded...")

        # define properties to be loaded
        load_only = [args.property]
        if args.derivative is not None:
            load_only.append(args.derivative)

        dataset = spk.AtomsData(
            args.datapath,
            load_only=load_only,
            collect_triples=args.model == "wacsf",
            environment_provider=environment_provider,
        )
        return dataset
    else:
        raise spk.utils.ScriptError("Invalid dataset selected!")
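A hedged usage sketch for the QM9 branch follows; the Namespace below only carries the attributes that branch reads, with placeholder values (the property key "energy_U0" is assumed).

from argparse import Namespace

import schnetpack as spk

args = Namespace(
    dataset="qm9",
    datapath="data/qm9.db",      # placeholder path
    property="energy_U0",        # assumed QM9 property key
    model="schnet",
    remove_uncharacterized=False,
)
env_provider = spk.environment.SimpleEnvironmentProvider()
qm9 = get_dataset(args, env_provider, logging=None)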
Example #4
import logging
import os

import schnetpack as spk
import schnetpack.atomistic.model
from schnetpack.train import Trainer, CSVHook, ReduceLROnPlateauHook
from schnetpack.train.metrics import MeanAbsoluteError
from schnetpack.train.metrics import mse_loss

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

# basic settings
model_dir = "ethanol_model"  # directory that will be created for storing model
os.makedirs(model_dir)
properties = ["energy", "forces"]  # properties used for training

# data preparation
logging.info("get dataset")
dataset = spk.AtomsData("data/ethanol.db", load_only=properties)
train, val, test = spk.train_test_split(
    data=dataset,
    num_train=1000,
    num_val=100,
    split_file=os.path.join(model_dir, "split.npz"),
)
train_loader = spk.AtomsLoader(train, batch_size=64)
val_loader = spk.AtomsLoader(val, batch_size=64)

# get statistics
atomrefs = dataset.get_atomrefs(properties)
per_atom = dict(energy=True, forces=False)
means, stddevs = train_loader.get_statistics(properties,
                                             single_atom_ref=atomrefs,
                                             get_atomwise_statistics=per_atom)
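The excerpt stops after computing the statistics. In SchNetPack tutorials of this vintage, the means, stddevs and atomrefs are typically fed into an Atomwise output head; the continuation below is a sketch under that assumption (schnetpack 0.3-style API, hyperparameters illustrative only).

# assumed continuation: build a SchNet representation and an energy/forces head
representation = spk.representation.SchNet(
    n_atom_basis=128, n_filters=128, n_interactions=3, cutoff=5.0
)
energy_output = spk.atomistic.Atomwise(
    n_in=128,
    property="energy",
    derivative="forces",
    negative_dr=True,                 # forces = -dE/dR
    mean=means["energy"],
    stddev=stddevs["energy"],
    atomref=atomrefs["energy"],
)
model = spk.atomistic.model.AtomisticModel(representation, energy_output)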
Example #5
def test_ani1(ani1_path, ani1_dataset):
    """
    Test if MD17 dataset object has same behaviour as AtomsData.
    """
    atoms_data = spk.AtomsData(ani1_path)
    assert_dataset_equal(atoms_data, ani1_dataset)
Example #6
def test_md17(ethanol_path, md17_dataset):
    """
    Test if MD17 dataset object has same behaviour as AtomsData.
    """
    atoms_data = spk.AtomsData(ethanol_path)
    assert_dataset_equal(atoms_data, md17_dataset)
Example #7
import logging
import os

from torch.optim import Adam
import schnetpack as spk
from schnetpack.train import Trainer, CSVHook, ReduceLROnPlateauHook
from schnetpack.metrics import MeanAbsoluteError
from schnetpack.metrics import mse_loss

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

# basic settings
model_dir = "ethanol_model"  # directory that will be created for storing model
os.makedirs(model_dir)
properties = ["energy", "forces"]  # properties used for training

# data preparation
logging.info("get dataset")
dataset = spk.AtomsData("data/ethanol.db", required_properties=properties)
train, val, test = spk.train_test_split(
    data=dataset,
    num_train=1000,
    num_val=100,
    split_file=os.path.join(model_dir, "split.npz"),
)
train_loader = spk.AtomsLoader(train, batch_size=64)
val_loader = spk.AtomsLoader(val, batch_size=64)

# get statistics
atomrefs = dataset.get_atomrefs(properties)
per_atom = dict(energy=True, forces=False)
means, stddevs = train_loader.get_statistics(properties,
                                             atomrefs=atomrefs,
                                             per_atom=per_atom)
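The Adam import above is not used within this excerpt; presumably the optimizer is constructed further down once a model has been assembled. A minimal sketch of that step, assuming a `model` object built from the statistics above:

# assumed continuation: optimizer over the assembled model's parameters
optimizer = Adam(model.parameters(), lr=1e-4)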