def test_qm9(qm9_path, qm9_dataset):
    """
    Check that the QM9 dataset object behaves exactly like a plain
    AtomsData opened on the same database file.
    """
    reference = spk.AtomsData(qm9_path)
    assert_dataset_equal(reference, qm9_dataset)
def get_dataset(args, train_args, environment_provider):
    """
    Build an AtomsData dataset from parsed script arguments.

    Args:
        args (argparse.Namespace): current run arguments (provides ``dbpath``).
        train_args (argparse.Namespace): arguments stored at training time
            (provide ``property``, optionally ``derivative``, and ``model``).
        environment_provider: environment provider passed through to AtomsData.

    Returns:
        spk.AtomsData: dataset restricted to the properties needed for training.
    """
    load_only = [train_args.property]
    # `derivative` may be missing on namespaces saved by older runs; use
    # getattr instead of a bare `except: pass`, which would also silently
    # swallow unrelated errors.
    derivative = getattr(train_args, "derivative", None)
    if derivative is not None:
        load_only.append(derivative)
    dataset = spk.AtomsData(
        args.dbpath,
        load_only=load_only,
        # wacsf models additionally need angular (triple) information
        collect_triples=train_args.model == "wacsf",
        environment_provider=environment_provider,
    )
    return dataset
def get_dataset(args, environment_provider, logging=None):
    """
    Get dataset from arguments.

    Args:
        args (argparse.Namespace): parsed arguments
        environment_provider (spk.environment.BaseEnvironmentProvider):
            environment-provider of dataset
        logging: logger

    Returns:
        spk.data.AtomsData: dataset

    Raises:
        spk.utils.ScriptError: if ``args.dataset`` names no known dataset.
    """
    name = args.dataset
    # wacsf models need angular (triple) information
    wacsf = args.model == "wacsf"

    if name == "qm9":
        if logging:
            logging.info("QM9 will be loaded...")
        return spk.datasets.QM9(
            args.datapath,
            download=True,
            load_only=[args.property],
            collect_triples=wacsf,
            remove_uncharacterized=args.remove_uncharacterized,
            environment_provider=environment_provider,
        )

    if name == "ani1":
        if logging:
            logging.info("ANI1 will be loaded...")
        return spk.datasets.ANI1(
            args.datapath,
            download=True,
            load_only=[args.property],
            collect_triples=wacsf,
            num_heavy_atoms=args.num_heavy_atoms,
            environment_provider=environment_provider,
        )

    if name == "md17":
        if logging:
            logging.info("MD17 will be loaded...")
        return spk.datasets.MD17(
            args.datapath,
            args.molecule,
            download=True,
            collect_triples=wacsf,
            environment_provider=environment_provider,
        )

    if name == "matproj":
        if logging:
            logging.info("Materials project will be loaded...")
        dataset = spk.datasets.MaterialsProject(
            args.datapath,
            apikey=args.apikey,
            download=True,
            load_only=[args.property],
            environment_provider=environment_provider,
        )
        # optionally restrict the snapshot to a given point in time
        if args.timestamp:
            dataset = dataset.at_timestamp(args.timestamp)
        return dataset

    if name == "omdb":
        if logging:
            logging.info("Organic Materials Database will be loaded...")
        return spk.datasets.OrganicMaterialsDatabase(
            args.datapath,
            download=True,
            load_only=[args.property],
            environment_provider=environment_provider,
        )

    if name == "custom":
        if logging:
            logging.info("Custom dataset will be loaded...")
        # define properties to be loaded
        wanted = [args.property]
        if args.derivative is not None:
            wanted.append(args.derivative)
        return spk.AtomsData(
            args.datapath,
            load_only=wanted,
            collect_triples=wacsf,
            environment_provider=environment_provider,
        )

    raise spk.utils.ScriptError("Invalid dataset selected!")
import schnetpack as spk
import schnetpack.atomistic.model
from schnetpack.train import Trainer, CSVHook, ReduceLROnPlateauHook
from schnetpack.train.metrics import MeanAbsoluteError
from schnetpack.train.metrics import mse_loss

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

# basic settings
model_dir = "ethanol_model"  # directory that will be created for storing model
# exist_ok avoids a FileExistsError when the tutorial script is re-run
os.makedirs(model_dir, exist_ok=True)
properties = ["energy", "forces"]  # properties used for training

# data preparation
logging.info("get dataset")
dataset = spk.AtomsData("data/ethanol.db", load_only=properties)
train, val, test = spk.train_test_split(
    data=dataset,
    num_train=1000,
    num_val=100,
    split_file=os.path.join(model_dir, "split.npz"),
)
train_loader = spk.AtomsLoader(train, batch_size=64)
val_loader = spk.AtomsLoader(val, batch_size=64)

# get statistics
atomrefs = dataset.get_atomrefs(properties)
# energy is normalized per atom, forces are not
per_atom = dict(energy=True, forces=False)
means, stddevs = train_loader.get_statistics(
    properties, single_atom_ref=atomrefs, get_atomwise_statistics=per_atom
)
def test_ani1(ani1_path, ani1_dataset):
    """
    Test if ANI1 dataset object has same behaviour as AtomsData.
    """
    atoms_data = spk.AtomsData(ani1_path)
    assert_dataset_equal(atoms_data, ani1_dataset)
def test_md17(ethanol_path, md17_dataset):
    """
    Check that the MD17 dataset object behaves exactly like a plain
    AtomsData opened on the same (ethanol) database file.
    """
    reference = spk.AtomsData(ethanol_path)
    assert_dataset_equal(reference, md17_dataset)
from torch.optim import Adam
import schnetpack as spk
from schnetpack.train import Trainer, CSVHook, ReduceLROnPlateauHook
from schnetpack.metrics import MeanAbsoluteError
from schnetpack.metrics import mse_loss

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))

# basic settings
model_dir = "ethanol_model"  # directory that will be created for storing model
# exist_ok avoids a FileExistsError when the tutorial script is re-run
os.makedirs(model_dir, exist_ok=True)
properties = ["energy", "forces"]  # properties used for training

# data preparation
logging.info("get dataset")
dataset = spk.AtomsData("data/ethanol.db", required_properties=properties)
train, val, test = spk.train_test_split(
    data=dataset,
    num_train=1000,
    num_val=100,
    split_file=os.path.join(model_dir, "split.npz"),
)
train_loader = spk.AtomsLoader(train, batch_size=64)
val_loader = spk.AtomsLoader(val, batch_size=64)

# get statistics
atomrefs = dataset.get_atomrefs(properties)
# energy is normalized per atom, forces are not
per_atom = dict(energy=True, forces=False)
means, stddevs = train_loader.get_statistics(
    properties, atomrefs=atomrefs, per_atom=per_atom
)