def test_seed(model_name):
    """Models created right after seeding with the same value must have identical weights.

    Two models are built from the same example arguments, each preceded by
    ``pl.seed_everything(1234)``, and their parameters are compared pairwise.
    """
    args = load_example_args(model_name, remove_prior=True)

    models = []
    for _ in range(2):
        # re-seed before every construction so both models draw the same init values
        pl.seed_everything(1234)
        models.append(create_model(args))

    first, second = models
    for param_a, param_b in zip(first.parameters(), second.parameters()):
        assert (param_a == param_b).all(), "Parameters don't match although using the same seed."
def test_forward(model_name, use_batch, explicit_q_s):
    """Smoke-test a single forward pass of the model.

    Args:
        model_name: example model configuration to load.
        use_batch: if False, the batch vector is replaced by ``None``.
        explicit_q_s: if True, ``q=None`` and ``s=None`` are passed explicitly
            instead of relying on the forward defaults.
    """
    z, pos, batch = create_example_batch()
    model = create_model(load_example_args(model_name, prior_model=None))

    if not use_batch:
        batch = None
    extra_kwargs = {"q": None, "s": None} if explicit_q_s else {}

    model(z, pos, batch=batch, **extra_kwargs)
def test_atomref(model_name):
    """Check that an ``Atomref`` prior shifts the output by the expected per-atom offset.

    Two models are created from the same seed: one with an ``Atomref`` prior and
    one without any prior. Their predictions must differ exactly by the atomref
    values of the batch's atoms, scatter-reduced over ``batch`` (``scatter`` is
    assumed to use its default sum reduction).
    """
    dataset = DummyDataset(has_atomref=True)
    atomref = Atomref(max_z=100, dataset=dataset)
    z, pos, batch = create_example_batch()

    # create model with atomref
    pl.seed_everything(1234)
    model_atomref = create_model(
        load_example_args(model_name, prior_model="Atomref"), prior_model=atomref
    )
    # create model without atomref (same seed, so identical learnable weights)
    pl.seed_everything(1234)
    model_no_atomref = create_model(load_example_args(model_name, remove_prior=True))

    # get output from both models
    x_atomref, _ = model_atomref(z, pos, batch)
    x_no_atomref, _ = model_no_atomref(z, pos, batch)

    # check if the output of both models differs by the expected atomref contribution
    expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
    # assert_allclose is deprecated; assert_close with check_dtype=False keeps its
    # dtype-agnostic comparison but no longer lets NaN silently equal NaN.
    torch.testing.assert_close(
        x_atomref, x_no_atomref + expected_offset, check_dtype=False
    )
def test_example_yamls(fname):
    """Build a model from an example YAML config and run two forward passes.

    If the config names a prior model, that prior is looked up on the ``priors``
    module and constructed from a ``DummyDataset`` that provides atomrefs. The
    forward pass is run once with default ``q``/``s`` and once with them passed
    explicitly as ``None``.
    """
    with open(fname, "r") as config_file:
        args = yaml.load(config_file, Loader=yaml.FullLoader)

    prior_name = args["prior_model"]
    if prior_name is None:
        prior = None
    else:
        dataset = DummyDataset(has_atomref=True)
        prior_cls = getattr(priors, prior_name)
        prior = prior_cls(dataset=dataset)

    model = create_model(args, prior_model=prior)
    z, pos, batch = create_example_batch()

    for call_kwargs in ({}, {"q": None, "s": None}):
        model(z, pos, batch, **call_kwargs)
def test_atom_filter(remove_threshold, model_name):
    """``AtomFilter`` must drop atoms whose atomic number is ``<= remove_threshold``.

    The representation part of an example model is wrapped in ``AtomFilter``.
    The returned ``z``, ``pos`` and ``batch`` must all be filtered consistently.
    """
    full_model = create_model(load_example_args(model_name, remove_prior=True))
    filtered_model = AtomFilter(full_model.representation_model, remove_threshold)

    z, pos, batch = create_example_batch(n_atoms=100)
    # only the (filtered) atom-level tensors are checked; features are discarded
    _, _, z, pos, batch = filtered_model(z, pos, batch, None, None)

    assert (z > remove_threshold).all(), (
        f"Lowest updated atomic number is {z.min()} but "
        f"the atom filter is set to {remove_threshold}"
    )
    assert len(z) == len(pos), "Number of z and pos values doesn't match after AtomFilter"
    assert len(z) == len(batch), "Number of z and batch values doesn't match after AtomFilter"
def test_scalar_invariance():
    """Scalar predictions of the equivariant transformer must be rotation invariant.

    The same random positions are fed to the model before and after
    multiplication with a fixed 3x3 rotation matrix. The first model output must
    be unchanged.
    """
    torch.manual_seed(1234)
    rotate = torch.tensor([
        [0.9886788, -0.1102370, 0.1017945],
        [0.1363630, 0.9431761, -0.3030248],
        [-0.0626055, 0.3134752, 0.9475304],
    ])

    model = create_model(
        load_example_args("equivariant-transformer", prior_model=None))
    # 100 atoms, all with z=1, grouped into 50 two-atom samples
    z = torch.ones(100, dtype=torch.long)
    pos = torch.randn(100, 3)
    batch = torch.arange(50, dtype=torch.long).repeat_interleave(2)

    y = model(z, pos, batch)[0]
    y_rot = model(z, pos @ rotate, batch)[0]
    # assert_allclose is deprecated and treated NaN == NaN as a pass; assert_close does not
    torch.testing.assert_close(y, y_rot)
def test_gn(device, num_atoms):
    """``optimize`` must not change the energies or gradients of a graph-network model.

    A non-optimized ``graph-network`` model with SchNet-like settings is built.
    It is run on random inputs and compared against ``optimize(model)`` on the
    same inputs. The test is skipped for ``device='cuda'`` when no GPU is
    available.
    """
    if not pt.cuda.is_available() and device == 'cuda':
        pytest.skip('No GPU')
    device = pt.device(device)

    # Seed the RNG so random inputs and model weights are reproducible: the
    # tolerances below are tight, so any failure must be replayable.
    pt.manual_seed(1234)

    # Generate random inputs (atomic numbers in [1, 99], below max_z=100)
    elements = pt.randint(1, 100, (num_atoms,)).to(device)
    positions = (10 * pt.rand((num_atoms, 3)) - 5).to(device)

    # Create a non-optimized model
    # SchNet: TorchMD_GN(rbf_type='gauss', trainable_rbf=False, activation='ssp', neighbor_embedding=False)
    args = {
        'embedding_dimension': 128,
        'num_layers': 6,
        'num_rbf': 50,
        'rbf_type': 'gauss',
        'trainable_rbf': False,
        'activation': 'ssp',
        'neighbor_embedding': False,
        'cutoff_lower': 0.0,
        'cutoff_upper': 5.0,
        'max_z': 100,
        'max_num_neighbors': num_atoms,
        'model': 'graph-network',
        'aggr': 'add',
        'derivative': True,
        'atom_filter': -1,
        'prior_model': None,
        'output_model': 'Scalar',
        'reduce_op': 'add',
    }
    ref_model = create_model(args).to(device)

    # Execute the non-optimized model
    ref_energy, ref_gradient = ref_model(elements, positions)

    # Optimize the model
    model = optimize(ref_model).to(device)

    # Execute the optimized model
    energy, gradient = model(elements, positions)

    assert pt.allclose(ref_energy, energy, atol=5e-7)
    assert pt.allclose(ref_gradient, gradient, atol=1e-5)
def test_gated_eq_gradients():
    """An isolated atom must not produce NaN gradients through the force output.

    One atom is placed farther than ``cutoff_upper=5`` from both other atoms. The
    sum of the model's second output (the forces) is differentiated with respect
    to the embedding weights, and the result must be NaN-free.
    """
    args = load_example_args(
        "equivariant-transformer", prior_model=None, cutoff_upper=5, derivative=True
    )
    model = create_model(args)

    # the third atom (at x=10) has no neighbour within the cutoff radius
    z = torch.tensor([1, 1, 8])
    pos = torch.tensor([[0, 0, 0], [0, 1, 0], [10, 0, 0]], dtype=torch.float)
    forces = model(z, pos)[1]

    embedding_weight = model.representation_model.embedding.weight
    (deriv,) = grad(forces.sum(), embedding_weight)

    assert not deriv.isnan().any(), (
        "Encountered NaN gradients while backpropagating the force loss"
    )
def test_forward_output(model_name, output_model, overwrite_reference=False):
    """Compare model outputs against pickled reference outputs in ``expected.pkl``.

    The reference file lives next to this test module. It maps
    ``model_name -> output_model -> {"pred": ..., "deriv": ...}``. Derivatives
    are compared only for output models built with ``derivative=True``
    ("Scalar" and "EquivariantScalar").

    Args:
        model_name: example model configuration to load.
        output_model: output head to attach to the model.
        overwrite_reference: if True, store the current outputs as the new
            reference, creating ``expected.pkl`` if it does not exist yet. The
            test then fails on purpose so that regeneration is never mistaken
            for a passing run.
    """
    pl.seed_everything(1234)

    # create model and sample batch
    derivative = output_model in ["Scalar", "EquivariantScalar"]
    args = load_example_args(
        model_name,
        remove_prior=True,
        output_model=output_model,
        derivative=derivative,
    )
    model = create_model(args)
    z, pos, batch = create_example_batch(n_atoms=5)

    # run step
    pred, deriv = model(z, pos, batch)

    # load reference outputs (a missing file is only acceptable when regenerating)
    expected_path = join(dirname(__file__), "expected.pkl")
    if overwrite_reference and not exists(expected_path):
        expected = {}
    else:
        assert exists(expected_path), "Couldn't locate reference outputs."
        with open(expected_path, "rb") as f:
            expected = pickle.load(f)

    if overwrite_reference:
        # this overwrites the previous reference outputs and shouldn't be executed during testing
        expected.setdefault(model_name, {})[output_model] = dict(pred=pred, deriv=deriv)
        with open(expected_path, "wb") as f:
            pickle.dump(expected, f)
        # explicit raise (not `assert False`) so the failure survives `python -O`
        raise AssertionError(
            f"Set new reference outputs for {model_name} with output model {output_model}."
        )

    # compare actual output with reference; assert_close replaces the deprecated
    # assert_allclose, keeping its dtype-agnostic comparison
    reference = expected[model_name][output_model]
    torch.testing.assert_close(pred, reference["pred"], check_dtype=False)
    if derivative:
        torch.testing.assert_close(deriv, reference["deriv"], check_dtype=False)
def test_forward_torchscript(model_name):
    """The model (with derivatives enabled) must compile with ``torch.jit.script`` and run."""
    z, pos, batch = create_example_batch()
    args = load_example_args(model_name, remove_prior=True, derivative=True)
    scripted_model = torch.jit.script(create_model(args))
    scripted_model(z, pos, batch=batch)
def test_forward_output_modules(model_name, output_model):
    """Smoke-test a forward pass of ``model_name`` combined with the given output module."""
    z, pos, batch = create_example_batch()
    model = create_model(
        load_example_args(model_name, remove_prior=True, output_model=output_model)
    )
    model(z, pos, batch=batch)