def schnet(
    n_atom_basis,
    n_filters,
    n_interactions,
    cutoff,
    n_gaussians,
    normalize_filter,
    coupled_interactions,
    return_intermediate,
    max_z,
    cutoff_network,
    trainable_gaussians,
    distance_expansion,
    charged_systems,
):
    """Construct an ``spk.SchNet`` representation from the given settings.

    Every argument is passed through unchanged, by keyword, to ``spk.SchNet``.
    NOTE(review): if this function is registered as a pytest fixture (any
    decorator would sit above this view), each parameter name is resolved as
    a fixture, so the names must stay exactly as they are.

    Returns:
        The ``spk.SchNet`` instance built from these settings.
    """
    # Gather the hyper-parameters under the exact keyword names SchNet expects.
    hyperparams = dict(
        n_atom_basis=n_atom_basis,
        n_filters=n_filters,
        n_interactions=n_interactions,
        cutoff=cutoff,
        n_gaussians=n_gaussians,
        normalize_filter=normalize_filter,
        coupled_interactions=coupled_interactions,
        return_intermediate=return_intermediate,
        max_z=max_z,
        cutoff_network=cutoff_network,
        trainable_gaussians=trainable_gaussians,
        distance_expansion=distance_expansion,
        charged_systems=charged_systems,
    )
    return spk.SchNet(**hyperparams)
def test_charge_correction(schnet_batch, n_atom_basis):
    """
    Check that charge correction makes the predicted atomic charges add up to
    the total charge requested for each molecule in the batch.
    """
    # Build the model first: its weight initialisation draws from the global
    # RNG before the random total charges are sampled below.
    representation = spk.SchNet(n_atom_basis)
    dipole_head = spk.atomistic.DipoleMoment(
        n_atom_basis, charge_correction="q", contributions="q"
    )
    model = spk.AtomisticModel(representation, dipole_head)

    # One random integer total charge in [0, 10) per molecule.
    n_molecules = schnet_batch["_positions"].shape[0]
    total_charge = torch.randint(0, 10, (n_molecules, 1))
    schnet_batch.update(q=total_charge)

    # Summing over dim 1 is assumed to aggregate over atoms
    # (presumably q_i is (batch, atoms, 1) — confirm).
    atomic_charges = model(schnet_batch)["q"]
    predicted_total = atomic_charges.sum(1)
    assert torch.allclose(total_charge.float(), predicted_total, atol=1e-6)
# Training setup for the QM9 U0 target: data, statistics, model, optimizer.
# NOTE(review): `properties` and `model_dir` are not defined in this section;
# presumably they are set earlier in the script. Confirm that `properties`
# agrees with the [QM9.U0] list passed to the dataset below.

# Load the QM9 database, requesting only the U0 property.
dataset = QM9("data/qm9.db", properties=[QM9.U0])
# Split off 1000 training and 100 validation samples. The split-file path
# under model_dir presumably lets the split be saved/reused — confirm.
train, val, test = spk.train_test_split(
    dataset, 1000, 100, os.path.join(model_dir, "split.npz")
)
train_loader = spk.AtomsLoader(train, batch_size=64)
val_loader = spk.AtomsLoader(val, batch_size=64)

# statistics
# Mean/stddev are computed from the training loader only, with atom
# references supplied. per_atom=True presumably means per-atom normalisation
# — confirm against the AtomsLoader.get_statistics docs.
atomrefs = dataset.get_atomrefs(properties)
means, stddevs = train_loader.get_statistics(
    properties, per_atom=True, atomrefs=atomrefs
)

# model build
logging.info("build model")
# SchNet representation with 6 interaction blocks; other settings use the
# library defaults.
representation = spk.SchNet(n_interactions=6)
# A single atom-wise output head for U0, given the training statistics and
# atom reference for that property.
output_modules = [
    spk.Atomwise(
        property=QM9.U0,
        mean=means[QM9.U0],
        stddev=stddevs[QM9.U0],
        atomref=atomrefs[QM9.U0],
    )
]
model = spk.AtomisticModel(representation, output_modules)

# build optimizer
optimizer = Adam(model.parameters(), lr=1e-4)

# hooks
logging.info("build trainer")
# get statistics atomrefs = dataset.get_atomref(properties) per_atom = dict(energy=True, forces=False) means, stddevs = train_loader.get_statistics( properties, single_atom_ref=atomrefs, divide_by_atoms=per_atom ) # model build logging.info("build model") ### Positional Embedding: ###representation = spk.SchNet(n_interactions=4,n_scales=1,n_filters=256, use_log_normal=False, n_gaussians=32,cutoff=10. ) representation = spk.SchNet(n_interactions=4,n_scales=1,n_filters=256, use_log_normal=True, n_gaussians=32,cutoff=10. ) ###representation = spk.SchNet(n_interactions=1,n_scales=4,n_filters=256, use_log_normal=False, n_gaussians=32,cutoff=10. ) ###representation = spk.SchNet(n_interactions=1,n_scales=4,n_filters=256, use_log_normal=True, n_gaussians=32,cutoff=10. ) ### Tansition: representation = spk.TDTNet(n_interactions=2, n_scales=4, n_heads=8, cutoff=10., use_act=True, use_mcr=True, trainable_gaussians=True) ########################### output_modules = [ spk.atomistic.Atomwise( n_in=representation.n_atom_basis, property="energy", derivative="forces",