def test_datamodule_standardize(energy, forces, has_atomref, tmpdir):
    """With standardize=True, DataModule must expose mean/std of the training
    energies (after subtracting atomrefs, when present) and None otherwise."""
    args = load_example_args("graph-network")
    args["standardize"] = True
    args["train_size"] = 800
    args["val_size"] = 100
    args["test_size"] = 100
    args["log_dir"] = tmpdir

    dataset = DummyDataset(energy=energy, forces=forces, has_atomref=has_atomref)
    data = DataModule(args, dataset=dataset)
    data.prepare_data()
    data.setup("fit")

    assert (data.atomref is not None) == has_atomref
    if has_atomref:
        assert (data.atomref == dataset.get_atomref()).all()

    if energy:
        train_energies = torch.tensor(dataset.energies)[data.idx_train]
        if has_atomref:
            # the mean and std should be computed after removing atomrefs;
            # build a set once so membership checks are O(1) instead of
            # scanning data.idx_train for every sample
            train_idx = {int(i) for i in data.idx_train}
            train_energies -= torch.tensor(
                [
                    dataset.atomref[zs].sum()
                    for i, zs in enumerate(dataset.z)
                    if i in train_idx
                ]
            )
        # mean and std attributes should provide mean and std of the training split
        assert torch.allclose(data.mean, train_energies.mean())
        # NOTE(review): the std check below is disabled — confirm whether
        # data.std is expected to equal train_energies.std() and re-enable if so
        # assert torch.allclose(data.std, train_energies.std())
    else:
        # the data module should not have mean and std set if the dataset does not include energies
        assert data.mean is None and data.std is None
def test_train(model_name, use_atomref, tmpdir):
    """Run a short fit + test cycle for the model, optionally with an Atomref prior."""
    args = load_example_args(
        model_name,
        remove_prior=not use_atomref,
        train_size=0.8,
        val_size=0.05,
        test_size=None,
        log_dir=tmpdir,
        derivative=True,
        embedding_dimension=16,
        num_layers=2,
        num_rbf=16,
        batch_size=8,
    )
    datamodule = DataModule(args, DummyDataset(has_atomref=use_atomref))

    prior = None
    if use_atomref:
        # look up the prior class by name and record its init args in the config
        prior_cls = getattr(priors, args["prior_model"])
        prior = prior_cls(dataset=datamodule.dataset)
        args["prior_args"] = prior.get_init_args()

    module = LNNP(args, prior_model=prior)
    trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir)
    trainer.fit(module, datamodule)
    trainer.test(module, datamodule)
def test_forward(model_name, use_batch, explicit_q_s):
    """Forward pass must work with or without a batch vector and with explicit q/s."""
    z, pos, batch = create_example_batch()
    model = create_model(load_example_args(model_name, prior_model=None))
    if not use_batch:
        batch = None
    if explicit_q_s:
        model(z, pos, batch=batch, q=None, s=None)
    else:
        model(z, pos, batch=batch)
def test_seed(model_name):
    """Two models created from identical seeds must have identical parameters."""
    args = load_example_args(model_name, remove_prior=True)
    models = []
    for _ in range(2):
        pl.seed_everything(1234)
        models.append(create_model(args))
    for p1, p2 in zip(models[0].parameters(), models[1].parameters()):
        assert (p1 == p2).all(), "Parameters don't match although using the same seed."
def test_atomref(model_name):
    """An Atomref prior must shift predictions by the per-molecule sum of atomrefs."""
    dataset = DummyDataset(has_atomref=True)
    atomref = Atomref(max_z=100, dataset=dataset)
    z, pos, batch = create_example_batch()

    # model with the atomref prior attached
    pl.seed_everything(1234)
    model_atomref = create_model(
        load_example_args(model_name, prior_model="Atomref"), prior_model=atomref
    )
    # identically-seeded model without any prior
    pl.seed_everything(1234)
    model_no_atomref = create_model(load_example_args(model_name, remove_prior=True))

    # run both models on the same batch
    x_atomref, _ = model_atomref(z, pos, batch)
    x_no_atomref, _ = model_no_atomref(z, pos, batch)

    # outputs should differ exactly by the summed atomic reference energies
    expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
    torch.testing.assert_allclose(x_atomref, x_no_atomref + expected_offset)
def test_atom_filter(remove_threshold, model_name):
    """AtomFilter must drop every atom with z <= threshold and keep outputs aligned."""
    # wrap the bare representation model in the AtomFilter wrapper
    base = create_model(load_example_args(model_name, remove_prior=True))
    model = AtomFilter(base.representation_model, remove_threshold)

    z, pos, batch = create_example_batch(n_atoms=100)
    x, v, z, pos, batch = model(z, pos, batch, None, None)

    assert (z > remove_threshold).all(), (
        f"Lowest updated atomic number is {z.min()} but "
        f"the atom filter is set to {remove_threshold}"
    )
    assert len(z) == len(
        pos
    ), "Number of z and pos values doesn't match after AtomFilter"
    assert len(z) == len(
        batch
    ), "Number of z and batch values doesn't match after AtomFilter"
def test_scalar_invariance():
    """Scalar predictions must be invariant under a rotation of the coordinates."""
    torch.manual_seed(1234)
    rotate = torch.tensor(
        [
            [0.9886788, -0.1102370, 0.1017945],
            [0.1363630, 0.9431761, -0.3030248],
            [-0.0626055, 0.3134752, 0.9475304],
        ]
    )
    model = create_model(load_example_args("equivariant-transformer", prior_model=None))

    # 50 two-atom systems with random coordinates
    z = torch.ones(100, dtype=torch.long)
    pos = torch.randn(100, 3)
    batch = torch.arange(50, dtype=torch.long).repeat_interleave(2)

    y = model(z, pos, batch)[0]
    y_rot = model(z, pos @ rotate, batch)[0]
    torch.testing.assert_allclose(y, y_rot)
def test_gated_eq_gradients():
    """Force gradients must stay finite when an atom is isolated beyond the cutoff."""
    model = create_model(
        load_example_args(
            "equivariant-transformer",
            prior_model=None,
            cutoff_upper=5,
            derivative=True,
        )
    )

    # example where one atom lies outside the cutoff radius of all others
    z = torch.tensor([1, 1, 8])
    pos = torch.tensor([[0, 0, 0], [0, 1, 0], [10, 0, 0]], dtype=torch.float)
    _, forces = model(z, pos)

    # gradients of the force loss w.r.t. the model's embedding weights
    deriv = grad(forces.sum(), model.representation_model.embedding.weight)[0]
    assert (
        not deriv.isnan().any()
    ), "Encountered NaN gradients while backpropagating the force loss"
def test_datamodule_create(tmpdir):
    """DataModule must hand out fresh dataloaders when store_dataloader=False."""
    args = load_example_args("graph-network")
    args["train_size"] = 800
    args["val_size"] = 100
    args["test_size"] = 100
    args["log_dir"] = tmpdir

    data = DataModule(args, dataset=DummyDataset())
    data.prepare_data()
    data.setup("fit")

    # every split should produce a dataloader without caching
    for split, ds in (
        ("train", data.train_dataset),
        ("val", data.val_dataset),
        ("test", data.test_dataset),
    ):
        data._get_dataloader(ds, split, store_dataloader=False)

    # without caching, two requests must yield distinct dataloader objects
    dl1 = data._get_dataloader(data.train_dataset, "train", store_dataloader=False)
    dl2 = data._get_dataloader(data.train_dataset, "train", store_dataloader=False)
    assert dl1 is not dl2
def test_forward_output(model_name, output_model, overwrite_reference=False):
    """Compare model predictions (and derivatives) with stored reference outputs."""
    pl.seed_everything(1234)

    # build the model and a small example batch
    derivative = output_model in ["Scalar", "EquivariantScalar"]
    args = load_example_args(
        model_name,
        remove_prior=True,
        output_model=output_model,
        derivative=derivative,
    )
    model = create_model(args)
    z, pos, batch = create_example_batch(n_atoms=5)

    # run one forward step
    pred, deriv = model(z, pos, batch)

    # load reference outputs
    expected_path = join(dirname(__file__), "expected.pkl")
    assert exists(expected_path), "Couldn't locate reference outputs."
    with open(expected_path, "rb") as f:
        expected = pickle.load(f)

    if overwrite_reference:
        # this overwrites the previous reference outputs and shouldn't be executed during testing
        expected.setdefault(model_name, {})[output_model] = dict(pred=pred, deriv=deriv)
        with open(expected_path, "wb") as f:
            pickle.dump(expected, f)
        assert (
            False
        ), f"Set new reference outputs for {model_name} with output model {output_model}."

    # compare actual output with reference
    torch.testing.assert_allclose(pred, expected[model_name][output_model]["pred"])
    if derivative:
        torch.testing.assert_allclose(
            deriv, expected[model_name][output_model]["deriv"]
        )
def test_forward_torchscript(model_name):
    """A TorchScript-compiled model must still run a forward pass."""
    z, pos, batch = create_example_batch()
    args = load_example_args(model_name, remove_prior=True, derivative=True)
    model = torch.jit.script(create_model(args))
    model(z, pos, batch=batch)
def test_forward_output_modules(model_name, output_model):
    """Every output module must support a plain forward pass."""
    z, pos, batch = create_example_batch()
    model = create_model(
        load_example_args(model_name, remove_prior=True, output_model=output_model)
    )
    model(z, pos, batch=batch)
def test_create_model(model_name):
    """Constructing an LNNP with an Atomref prior must not raise."""
    args = load_example_args(model_name)
    prior = Atomref(100)
    LNNP(args, prior_model=prior)