def test_forward(model_name, use_batch, explicit_q_s):
    """Smoke-test one forward pass of ``model_name``.

    ``use_batch`` toggles whether the batch index is passed (``None`` otherwise);
    ``explicit_q_s`` toggles whether ``q=None`` and ``s=None`` are passed explicitly.
    """
    z, pos, batch = create_example_batch()
    example_args = load_example_args(model_name, prior_model=None)
    model = create_model(example_args)

    batch_arg = batch if use_batch else None
    # An empty mapping leaves the call identical to one without q/s keywords.
    extra_kwargs = {"q": None, "s": None} if explicit_q_s else {}
    model(z, pos, batch=batch_arg, **extra_kwargs)
def test_compare_forward_multiple():
    """Compare ``External.calculate`` with a direct model call on two molecules.

    Two single-molecule examples are stacked for the calculator and concatenated
    (with a 0/1 batch index) for the model; energies and forces must agree.
    """
    ckpt_path = join(dirname(dirname(__file__)), "tests", "example.ckpt")
    z1, pos1, _ = create_example_batch(multiple_batches=False)
    z2, pos2, _ = create_example_batch(multiple_batches=False)

    calc = External(ckpt_path, torch.stack([z1, z2], dim=0))
    model = load_model(ckpt_path, derivative=True)

    e_calc, f_calc = calc.calculate(torch.cat([pos1, pos2], dim=0), None)

    # Molecule 1 atoms get index 0, molecule 2 atoms get index 1.
    batch_index = torch.cat([torch.zeros(len(z1)), torch.ones(len(z2))]).long()
    e_pred, f_pred = model(
        torch.cat([z1, z2]),
        torch.cat([pos1, pos2], dim=0),
        batch_index,
    )

    assert_allclose(e_calc, e_pred)
    # The model returns flat per-atom forces; reshape to (molecules, atoms, 3).
    assert_allclose(f_calc, f_pred.view(-1, len(z1), 3))
def test_compare_forward():
    """Compare ``External.calculate`` with a direct model call on one molecule."""
    ckpt_path = join(dirname(dirname(__file__)), "tests", "example.ckpt")
    z, pos, _ = create_example_batch(multiple_batches=False)

    # External expects a leading molecule dimension on the atomic numbers.
    calc = External(ckpt_path, z.unsqueeze(0))
    model = load_model(ckpt_path, derivative=True)

    e_calc, f_calc = calc.calculate(pos, None)
    e_pred, f_pred = model(z, pos)

    assert_allclose(e_calc, e_pred)
    assert_allclose(f_calc, f_pred.unsqueeze(0))
def test_example_yamls(fname):
    """Build a model from the example YAML config ``fname`` and run it.

    When the config names a prior, that prior class is looked up on ``priors``
    and instantiated with a ``DummyDataset`` that provides an atomref.
    """
    with open(fname, "r") as config_file:
        args = yaml.load(config_file, Loader=yaml.FullLoader)

    prior_name = args["prior_model"]
    prior = None
    if prior_name is not None:
        atomref_dataset = DummyDataset(has_atomref=True)
        prior = getattr(priors, prior_name)(dataset=atomref_dataset)

    model = create_model(args, prior_model=prior)
    z, pos, batch = create_example_batch()

    # Exercise both the implicit and the explicit q/s call forms.
    model(z, pos, batch)
    model(z, pos, batch, q=None, s=None)
def test_atom_filter(remove_threshold, model_name):
    """Check the AtomFilter wrapper around a representation model.

    Every atomic number left after filtering must exceed ``remove_threshold``,
    and the filtered ``z``, ``pos`` and ``batch`` must have matching lengths.
    """
    # AtomFilter wraps the bare representation model, not the full model.
    full_model = create_model(load_example_args(model_name, remove_prior=True))
    filtered_model = AtomFilter(full_model.representation_model, remove_threshold)

    z, pos, batch = create_example_batch(n_atoms=100)
    _x, _v, z, pos, batch = filtered_model(z, pos, batch, None, None)

    assert (z > remove_threshold).all(), (
        f"Lowest updated atomic number is {z.min()} but "
        f"the atom filter is set to {remove_threshold}"
    )
    n_remaining = len(z)
    assert n_remaining == len(pos), (
        "Number of z and pos values doesn't match after AtomFilter"
    )
    assert n_remaining == len(batch), (
        "Number of z and batch values doesn't match after AtomFilter"
    )
def test_forward_output(model_name, output_model, overwrite_reference=False):
    """Compare a seeded forward pass against the pickled reference outputs.

    Args:
        model_name: Example model configuration to build.
        output_model: Output head to attach to the model.
        overwrite_reference: If True, store this run's outputs in
            ``expected.pkl`` and then fail on purpose. This is a maintenance
            switch for regenerating references, never for normal testing.

    Raises:
        AssertionError: If the reference file is missing, if the outputs differ
            from the reference, or (by design) after overwriting the reference.
    """
    pl.seed_everything(1234)

    # Derivatives are requested (and later compared) only for these output heads.
    derivative = output_model in ["Scalar", "EquivariantScalar"]
    args = load_example_args(
        model_name,
        remove_prior=True,
        output_model=output_model,
        derivative=derivative,
    )
    model = create_model(args)
    z, pos, batch = create_example_batch(n_atoms=5)

    # run step
    pred, deriv = model(z, pos, batch)

    # load reference outputs
    expected_path = join(dirname(__file__), "expected.pkl")
    assert exists(expected_path), "Couldn't locate reference outputs."
    with open(expected_path, "rb") as f:
        # Repository-local, trusted fixture; pickle must never load external data.
        expected = pickle.load(f)

    if overwrite_reference:
        # This overwrites the previous reference outputs and shouldn't be
        # executed during testing.
        expected.setdefault(model_name, {})[output_model] = dict(pred=pred, deriv=deriv)
        with open(expected_path, "wb") as f:
            pickle.dump(expected, f)
        # Raise explicitly rather than `assert False`: assert statements are
        # stripped under `python -O`, which would let this run fall through and
        # "pass" against the reference it just wrote.
        raise AssertionError(
            f"Set new reference outputs for {model_name} with output model {output_model}."
        )

    # compare actual output with reference
    reference = expected[model_name][output_model]
    torch.testing.assert_allclose(pred, reference["pred"])
    if derivative:
        torch.testing.assert_allclose(deriv, reference["deriv"])
def test_atomref(model_name):
    """Check that an Atomref prior shifts outputs by the summed per-atom references.

    Two models are built after identical seeding, one with the ``Atomref``
    prior and one without. Their outputs must differ by the atomref values of
    each atom, scattered over ``batch``.
    """
    dataset = DummyDataset(has_atomref=True)
    atomref = Atomref(max_z=100, dataset=dataset)
    z, pos, batch = create_example_batch()

    # Re-seeding before each construction is meant to give both models the same
    # initial weights (assumes create_model is deterministic under the seed).
    pl.seed_everything(1234)
    with_prior = create_model(
        load_example_args(model_name, prior_model="Atomref"), prior_model=atomref
    )
    pl.seed_everything(1234)
    without_prior = create_model(load_example_args(model_name, remove_prior=True))

    out_with_prior, _ = with_prior(z, pos, batch)
    out_without_prior, _ = without_prior(z, pos, batch)

    per_atom_ref = dataset.get_atomref().squeeze()[z]
    expected_offset = scatter(per_atom_ref, batch).unsqueeze(1)
    torch.testing.assert_allclose(out_with_prior, out_without_prior + expected_offset)
def test_forward_torchscript(model_name):
    """Compile a derivative-enabled model with TorchScript and run one forward pass."""
    z, pos, batch = create_example_batch()
    script_args = load_example_args(model_name, remove_prior=True, derivative=True)
    scripted_model = torch.jit.script(create_model(script_args))
    scripted_model(z, pos, batch=batch)
def test_forward_output_modules(model_name, output_model):
    """Smoke-test a forward pass of ``model_name`` with output head ``output_model``."""
    z, pos, batch = create_example_batch()
    model = create_model(
        load_example_args(model_name, remove_prior=True, output_model=output_model)
    )
    model(z, pos, batch=batch)