Example #1
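A smoke test for the forward pass: the model should run with or without a batch vector, and with the optional q and s arguments passed explicitly.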
def test_forward(model_name, use_batch, explicit_q_s):
    z, pos, batch = create_example_batch()
    model = create_model(load_example_args(model_name, prior_model=None))
    batch = batch if use_batch else None
    if explicit_q_s:
        # pass the optional q and s arguments explicitly (here both None)
        model(z, pos, batch=batch, q=None, s=None)
    else:
        model(z, pos, batch=batch)
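None of the snippets on this page include their imports. The sketch below is a plausible preamble for all of them; the module paths are assumptions inferred from how the names are used, not verified against the source tree.

# Preamble sketch -- module paths are assumptions, not verified
from os.path import dirname, exists, join
import pickle

import pytorch_lightning as pl
import torch
import yaml
from torch.testing import assert_allclose
from torch_scatter import scatter

from torchmdnet import priors
from torchmdnet.calculators import External
from torchmdnet.models.model import create_model, load_model
from torchmdnet.models.wrappers import AtomFilter
from torchmdnet.priors import Atomref

# shared test helpers, assumed to sit next to these tests
from utils import DummyDataset, create_example_batch, load_example_args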
Example #2
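Checks that the External calculator and a directly loaded model agree on energies and forces for a batch of two molecules.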
def test_compare_forward_multiple():
    checkpoint = join(dirname(dirname(__file__)), "tests", "example.ckpt")
    z1, pos1, _ = create_example_batch(multiple_batches=False)
    z2, pos2, _ = create_example_batch(multiple_batches=False)
    calc = External(checkpoint, torch.stack([z1, z2], dim=0))
    model = load_model(checkpoint, derivative=True)

    e_calc, f_calc = calc.calculate(torch.cat([pos1, pos2], dim=0), None)
    e_pred, f_pred = model(
        torch.cat([z1, z2]),
        torch.cat([pos1, pos2], dim=0),
        # batch vector: atoms of the first molecule -> index 0, second -> index 1
        torch.cat([torch.zeros(len(z1)), torch.ones(len(z2))]).long(),
    )

    assert_allclose(e_calc, e_pred)
    # the calculator returns forces with a leading molecule dimension, hence the reshape
    assert_allclose(f_calc, f_pred.view(-1, len(z1), 3))
Example #3
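The same comparison for a single molecule: the External calculator and the loaded model must produce matching energies and forces.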
def test_compare_forward():
    checkpoint = join(dirname(dirname(__file__)), "tests", "example.ckpt")
    z, pos, _ = create_example_batch(multiple_batches=False)
    calc = External(checkpoint, z.unsqueeze(0))
    model = load_model(checkpoint, derivative=True)

    e_calc, f_calc = calc.calculate(pos, None)
    e_pred, f_pred = model(z, pos)

    assert_allclose(e_calc, e_pred)
    # the calculator keeps a leading molecule dimension, hence the unsqueeze
    assert_allclose(f_calc, f_pred.unsqueeze(0))
Example #4
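Builds a model from each example YAML configuration, attaching a prior when the config requests one, and runs a forward pass.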
def test_example_yamls(fname):
    with open(fname, "r") as f:
        args = yaml.load(f, Loader=yaml.FullLoader)

    # instantiate the prior requested by the config, if any
    prior = None
    if args["prior_model"] is not None:
        dataset = DummyDataset(has_atomref=True)
        prior = getattr(priors, args["prior_model"])(dataset=dataset)

    model = create_model(args, prior_model=prior)

    z, pos, batch = create_example_batch()
    model(z, pos, batch)
    model(z, pos, batch, q=None, s=None)
Example #5
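Wraps a representation model in AtomFilter and verifies that atoms with atomic numbers at or below the threshold are removed consistently from z, pos, and batch.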
def test_atom_filter(remove_threshold, model_name):
    # wrap a representation model using the AtomFilter wrapper
    model = create_model(load_example_args(model_name, remove_prior=True))
    model = model.representation_model
    model = AtomFilter(model, remove_threshold)

    z, pos, batch = create_example_batch(n_atoms=100)
    x, v, z, pos, batch = model(z, pos, batch, None, None)

    assert (z > remove_threshold).all(), (
        f"Lowest updated atomic number is {z.min()} but "
        f"the atom filter is set to {remove_threshold}"
    )
    assert len(z) == len(pos), "Number of z and pos values doesn't match after AtomFilter"
    assert len(z) == len(batch), "Number of z and batch values doesn't match after AtomFilter"
Example #6
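A regression test: predictions (and derivatives, when the output model produces them) are compared against pickled reference outputs. Running it with overwrite_reference=True regenerates the pickle and then fails on purpose, so a regeneration run can never be mistaken for a passing test.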
def test_forward_output(model_name, output_model, overwrite_reference=False):
    pl.seed_everything(1234)

    # create model and sample batch
    derivative = output_model in ["Scalar", "EquivariantScalar"]
    args = load_example_args(
        model_name, remove_prior=True, output_model=output_model, derivative=derivative,
    )
    model = create_model(args)
    z, pos, batch = create_example_batch(n_atoms=5)

    # run step
    pred, deriv = model(z, pos, batch)

    # load reference outputs
    expected_path = join(dirname(__file__), "expected.pkl")
    assert exists(expected_path), "Couldn't locate reference outputs."
    with open(expected_path, "rb") as f:
        expected = pickle.load(f)

    if overwrite_reference:
        # this overwrites the previous reference outputs and shouldn't be executed during testing
        if model_name in expected:
            expected[model_name][output_model] = dict(pred=pred, deriv=deriv)
        else:
            expected[model_name] = {output_model: dict(pred=pred, deriv=deriv)}

        with open(expected_path, "wb") as f:
            pickle.dump(expected, f)
        assert (
            False
        ), f"Set new reference outputs for {model_name} with output model {output_model}."

    # compare actual output with reference
    torch.testing.assert_allclose(pred, expected[model_name][output_model]["pred"])
    if derivative:
        torch.testing.assert_allclose(
            deriv, expected[model_name][output_model]["deriv"]
        )
Example #7
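Verifies that adding an Atomref prior shifts the model output by exactly the per-molecule sum of atomic reference values.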
def test_atomref(model_name):
    dataset = DummyDataset(has_atomref=True)
    atomref = Atomref(max_z=100, dataset=dataset)
    z, pos, batch = create_example_batch()

    # create a model with the Atomref prior
    pl.seed_everything(1234)
    model_atomref = create_model(
        load_example_args(model_name, prior_model="Atomref"), prior_model=atomref
    )
    # create an identically seeded model without the prior
    pl.seed_everything(1234)
    model_no_atomref = create_model(load_example_args(model_name, remove_prior=True))

    # get output from both models
    x_atomref, _ = model_atomref(z, pos, batch)
    x_no_atomref, _ = model_no_atomref(z, pos, batch)

    # the outputs should differ by exactly the per-molecule sum of atomic reference values
    expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
    torch.testing.assert_allclose(x_atomref, x_no_atomref + expected_offset)
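As a sanity check on the last two lines, here is a minimal, self-contained sketch of the expected-offset computation with invented numbers, assuming scatter is torch_scatter.scatter with its default sum reduction:

import torch
from torch_scatter import scatter

atomref = torch.tensor([0.0, -0.5, -1.0, -1.5])  # reference energy per atomic number (made up)
z = torch.tensor([1, 1, 3, 2, 2])                # atomic numbers of five atoms
batch = torch.tensor([0, 0, 0, 1, 1])            # atoms 0-2 -> molecule 0, atoms 3-4 -> molecule 1

offset = scatter(atomref[z], batch)              # per-molecule sums: tensor([-2.5, -2.0])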
Example #8
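Checks that models survive torch.jit.script compilation and still run a forward pass.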
def test_forward_torchscript(model_name):
    z, pos, batch = create_example_batch()
    # compilation with torch.jit.script must succeed for every model
    model = torch.jit.script(
        create_model(load_example_args(model_name, remove_prior=True, derivative=True))
    )
    model(z, pos, batch=batch)
Example #9
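Runs a forward pass for each combination of model and output module.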
def test_forward_output_modules(model_name, output_model):
    z, pos, batch = create_example_batch()
    args = load_example_args(model_name, remove_prior=True, output_model=output_model)
    model = create_model(args)
    model(z, pos, batch=batch)