def test_datamodule_standardize(energy, forces, has_atomref, tmpdir):
    args = load_example_args("graph-network")
    args["standardize"] = True
    args["train_size"] = 800
    args["val_size"] = 100
    args["test_size"] = 100
    args["log_dir"] = tmpdir

    dataset = DummyDataset(energy=energy, forces=forces, has_atomref=has_atomref)
    data = DataModule(args, dataset=dataset)
    data.prepare_data()
    data.setup("fit")

    assert (data.atomref is not None) == has_atomref
    if has_atomref:
        assert (data.atomref == dataset.get_atomref()).all()

    if energy:
        train_energies = torch.tensor(dataset.energies)[data.idx_train]
        if has_atomref:
            # the mean and std should be computed after removing atomrefs
            train_energies -= torch.tensor(
                [
                    dataset.atomref[zs].sum()
                    for i, zs in enumerate(dataset.z)
                    if i in data.idx_train
                ]
            )
        # mean and std attributes should provide mean and std of the training split
        assert torch.allclose(data.mean, train_energies.mean())
        assert torch.allclose(data.std, train_energies.std())
    else:
        # the data module should not have mean and std set if the dataset does not include energies
        assert data.mean is None and data.std is None
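
# For reference, a plausible import preamble for these snippets (a sketch; the
# project-local helpers -- load_example_args, create_example_batch, DummyDataset,
# create_model, DataModule, LNNP, Atomref, AtomFilter and the priors module --
# come from torchmd-net and its test utilities, whose exact import paths are
# not shown in these excerpts):
import pickle
from os.path import dirname, exists, join

import torch
import pytorch_lightning as pl
from torch.autograd import grad
from torch_scatter import scatter  # assumed source of scatter; may differ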

def test_train(model_name, use_atomref, tmpdir):
    args = load_example_args(
        model_name,
        remove_prior=not use_atomref,
        train_size=0.8,
        val_size=0.05,
        test_size=None,
        log_dir=tmpdir,
        derivative=True,
        embedding_dimension=16,
        num_layers=2,
        num_rbf=16,
        batch_size=8,
    )
    datamodule = DataModule(args, DummyDataset(has_atomref=use_atomref))

    prior = None
    if use_atomref:
        prior = getattr(priors, args["prior_model"])(dataset=datamodule.dataset)
        args["prior_args"] = prior.get_init_args()

    module = LNNP(args, prior_model=prior)

    trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir)
    trainer.fit(module, datamodule)
    trainer.test(module, datamodule)
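
# model_name and use_atomref above are pytest parameters in the original suite;
# a minimal sketch of the parametrization (the value lists are assumptions):
#
#     @pytest.mark.parametrize("model_name", ["graph-network"])
#     @pytest.mark.parametrize("use_atomref", [True, False])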

def test_forward(model_name, use_batch, explicit_q_s):
    z, pos, batch = create_example_batch()
    model = create_model(load_example_args(model_name, prior_model=None))
    batch = batch if use_batch else None
    if explicit_q_s:
        model(z, pos, batch=batch, q=None, s=None)
    else:
        model(z, pos, batch=batch)

def test_seed(model_name):
    args = load_example_args(model_name, remove_prior=True)
    pl.seed_everything(1234)
    m1 = create_model(args)
    pl.seed_everything(1234)
    m2 = create_model(args)

    for p1, p2 in zip(m1.parameters(), m2.parameters()):
        assert (p1 == p2).all(), "Parameters don't match despite using the same seed."

def test_atomref(model_name):
    dataset = DummyDataset(has_atomref=True)
    atomref = Atomref(max_z=100, dataset=dataset)
    z, pos, batch = create_example_batch()

    # create model with atomref
    pl.seed_everything(1234)
    model_atomref = create_model(
        load_example_args(model_name, prior_model="Atomref"), prior_model=atomref
    )
    # create model without atomref
    pl.seed_everything(1234)
    model_no_atomref = create_model(load_example_args(model_name, remove_prior=True))

    # get output from both models
    x_atomref, _ = model_atomref(z, pos, batch)
    x_no_atomref, _ = model_no_atomref(z, pos, batch)

    # check if the output of both models differs by the expected atomref contribution
    expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
    torch.testing.assert_allclose(x_atomref, x_no_atomref + expected_offset)

def test_atom_filter(remove_threshold, model_name):
    # wrap a representation model using the AtomFilter wrapper
    model = create_model(load_example_args(model_name, remove_prior=True))
    model = model.representation_model
    model = AtomFilter(model, remove_threshold)

    z, pos, batch = create_example_batch(n_atoms=100)
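    # run the wrapped model; AtomFilter drops atoms with z <= remove_threshold
    # and returns the filtered z, pos and batch alongside the representations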
    x, v, z, pos, batch = model(z, pos, batch, None, None)

    assert (z > remove_threshold).all(), (
        f"Lowest updated atomic number is {z.min()} but "
        f"the atom filter is set to {remove_threshold}"
    )
    assert len(z) == len(pos), "Number of z and pos values doesn't match after AtomFilter"
    assert len(z) == len(batch), "Number of z and batch values doesn't match after AtomFilter"

def test_scalar_invariance():
    torch.manual_seed(1234)
    rotate = torch.tensor([
        [0.9886788, -0.1102370, 0.1017945],
        [0.1363630, 0.9431761, -0.3030248],
        [-0.0626055, 0.3134752, 0.9475304],
    ])
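
    # sanity check (not part of the original test): the matrix above should be
    # a proper rotation, i.e. orthonormal with determinant +1
    assert torch.allclose(rotate @ rotate.T, torch.eye(3), atol=1e-5)
    assert torch.isclose(torch.det(rotate), torch.tensor(1.0), atol=1e-5)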

    model = create_model(
        load_example_args("equivariant-transformer", prior_model=None)
    )
    z = torch.ones(100, dtype=torch.long)
    pos = torch.randn(100, 3)
    batch = torch.arange(50, dtype=torch.long).repeat_interleave(2)

    y = model(z, pos, batch)[0]
    y_rot = model(z, pos @ rotate, batch)[0]
    torch.testing.assert_allclose(y, y_rot)

def test_gated_eq_gradients():
    model = create_model(
        load_example_args(
            "equivariant-transformer", prior_model=None, cutoff_upper=5, derivative=True
        )
    )

    # generate example where one atom is outside the cutoff radius of all others
    z = torch.tensor([1, 1, 8])
    pos = torch.tensor([[0, 0, 0], [0, 1, 0], [10, 0, 0]], dtype=torch.float)

    _, forces = model(z, pos)

    # compute gradients of the forces with respect to the model's embedding weights
    deriv = grad(forces.sum(), model.representation_model.embedding.weight)[0]
    assert (
        not deriv.isnan().any()
    ), "Encountered NaN gradients while backpropagating the force loss"

def test_datamodule_create(tmpdir):
    args = load_example_args("graph-network")
    args["train_size"] = 800
    args["val_size"] = 100
    args["test_size"] = 100
    args["log_dir"] = tmpdir

    dataset = DummyDataset()
    data = DataModule(args, dataset=dataset)
    data.prepare_data()
    data.setup("fit")

    data._get_dataloader(data.train_dataset, "train", store_dataloader=False)
    data._get_dataloader(data.val_dataset, "val", store_dataloader=False)
    data._get_dataloader(data.test_dataset, "test", store_dataloader=False)

    dl1 = data._get_dataloader(data.train_dataset, "train", store_dataloader=False)
    dl2 = data._get_dataloader(data.train_dataset, "train", store_dataloader=False)
    assert dl1 is not dl2
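
    # with store_dataloader=True the module would instead be expected to cache
    # and return the same object on repeated calls (inferred from the assertion
    # above; not part of the original test):
    #     dl3 = data._get_dataloader(data.train_dataset, "train", store_dataloader=True)
    #     dl4 = data._get_dataloader(data.train_dataset, "train", store_dataloader=True)
    #     assert dl3 is dl4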

def test_forward_output(model_name, output_model, overwrite_reference=False):
    pl.seed_everything(1234)

    # create model and sample batch
    derivative = output_model in ["Scalar", "EquivariantScalar"]
    args = load_example_args(
        model_name, remove_prior=True, output_model=output_model, derivative=derivative,
    )
    model = create_model(args)
    z, pos, batch = create_example_batch(n_atoms=5)

    # run step
    pred, deriv = model(z, pos, batch)

    # load reference outputs
    expected_path = join(dirname(__file__), "expected.pkl")
    assert exists(expected_path), "Couldn't locate reference outputs."
    with open(expected_path, "rb") as f:
        expected = pickle.load(f)

    if overwrite_reference:
        # this overwrites the previous reference outputs and shouldn't be executed during testing
        if model_name in expected:
            expected[model_name][output_model] = dict(pred=pred, deriv=deriv)
        else:
            expected[model_name] = {output_model: dict(pred=pred, deriv=deriv)}

        with open(expected_path, "wb") as f:
            pickle.dump(expected, f)
        assert (
            False
        ), f"Set new reference outputs for {model_name} with output model {output_model}."

    # compare actual output with reference
    torch.testing.assert_allclose(pred, expected[model_name][output_model]["pred"])
    if derivative:
        torch.testing.assert_allclose(
            deriv, expected[model_name][output_model]["deriv"]
        )
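
# Regenerating the reference file amounts to calling the test manually with
# overwrite_reference=True, which rewrites expected.pkl and then fails on
# purpose so it cannot pass silently during a normal run (hypothetical call):
#     test_forward_output("graph-network", "Scalar", overwrite_reference=True)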

def test_forward_torchscript(model_name):
    z, pos, batch = create_example_batch()
    model = torch.jit.script(
        create_model(load_example_args(model_name, remove_prior=True, derivative=True))
    )
    model(z, pos, batch=batch)

def test_forward_output_modules(model_name, output_model):
    z, pos, batch = create_example_batch()
    args = load_example_args(model_name, remove_prior=True, output_model=output_model)
    model = create_model(args)
    model(z, pos, batch=batch)

def test_create_model(model_name):
    LNNP(load_example_args(model_name), prior_model=Atomref(100))
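
# Atomref(100) here corresponds to Atomref(max_z=100) as used in test_atomref
# above (assuming max_z is the first positional argument)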