Example #1
0
def test_seed(model_name):
    """Check that seeding before construction yields identical weights.

    Two models are built from the same example arguments (with the prior
    removed), each one right after ``pl.seed_everything(1234)``. Every pair of
    corresponding parameters must then compare exactly equal.
    """
    config = load_example_args(model_name, remove_prior=True)

    built = []
    for _ in range(2):
        # reseed before every construction so both models see the same RNG stream
        pl.seed_everything(1234)
        built.append(create_model(config))
    first, second = built

    for param_a, param_b in zip(first.parameters(), second.parameters()):
        assert (param_a == param_b).all(), "Parameters don't match although using the same seed."
Example #2
0
def test_forward(model_name, use_batch, explicit_q_s):
    """Smoke-test a forward pass of the example model.

    Args:
        model_name: example model configuration to load (no prior model).
        use_batch: when False, ``batch=None`` is passed instead of the batch vector.
        explicit_q_s: when True, ``q=None`` and ``s=None`` are passed explicitly.
    """
    z, pos, batch = create_example_batch()
    model = create_model(load_example_args(model_name, prior_model=None))
    if not use_batch:
        batch = None
    optional_kwargs = {"q": None, "s": None} if explicit_q_s else {}
    model(z, pos, batch=batch, **optional_kwargs)
Example #3
0
def test_atomref(model_name):
    """Check that an Atomref prior shifts the output by the summed atomrefs.

    Two models are built from the same seed. One has an ``Atomref`` prior
    backed by a ``DummyDataset`` with atomrefs; the other has its prior
    removed. On the same example batch, their outputs must differ by the
    per-atom atomref values (indexed by ``z``) accumulated per entry of
    ``batch``.
    """
    dataset = DummyDataset(has_atomref=True)
    atomref = Atomref(max_z=100, dataset=dataset)
    z, pos, batch = create_example_batch()

    # create model with atomref; reseed so both models start from identical weights
    pl.seed_everything(1234)
    model_atomref = create_model(
        load_example_args(model_name, prior_model="Atomref"), prior_model=atomref
    )
    # create model without atomref
    pl.seed_everything(1234)
    model_no_atomref = create_model(load_example_args(model_name, remove_prior=True))

    # get output from both models
    x_atomref, _ = model_atomref(z, pos, batch)
    x_no_atomref, _ = model_no_atomref(z, pos, batch)

    # expected contribution: atomref of every atom scattered over `batch`
    # (scatter's default reduction is presumably a sum -- TODO confirm), turned
    # into a column by unsqueeze(1) to line up with the model output.
    expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
    # torch.testing.assert_allclose is deprecated; assert_close is its successor.
    # check_dtype=False keeps the old dtype-tolerant comparison, since the
    # offset comes from the dataset rather than from the model.
    torch.testing.assert_close(
        x_atomref, x_no_atomref + expected_offset, check_dtype=False
    )
Example #4
0
def test_example_yamls(fname):
    """Build a model from an example YAML config and run it twice.

    If the config's ``prior_model`` is not None, that name is looked up in
    ``priors`` and instantiated with an atomref-carrying ``DummyDataset``. The
    model is then called once with positional inputs only, and once with
    explicit ``q=None`` and ``s=None``.
    """
    with open(fname, "r") as handle:
        config = yaml.load(handle, Loader=yaml.FullLoader)

    prior_name = config["prior_model"]
    if prior_name is None:
        prior = None
    else:
        dataset = DummyDataset(has_atomref=True)
        prior = getattr(priors, prior_name)(dataset=dataset)

    model = create_model(config, prior_model=prior)

    z, pos, batch = create_example_batch()
    for extra_kwargs in ({}, {"q": None, "s": None}):
        model(z, pos, batch, **extra_kwargs)
Example #5
0
def test_atom_filter(remove_threshold, model_name):
    """Check the AtomFilter wrapper around a representation model.

    After a forward pass through the wrapped representation model, every
    returned atomic number must exceed ``remove_threshold``. The returned
    ``z``, ``pos`` and ``batch`` must also still have matching lengths.
    """
    # only the representation part of the example model gets wrapped
    full_model = create_model(load_example_args(model_name, remove_prior=True))
    filtered = AtomFilter(full_model.representation_model, remove_threshold)

    z, pos, batch = create_example_batch(n_atoms=100)
    _, _, z, pos, batch = filtered(z, pos, batch, None, None)

    assert (z > remove_threshold).all(), (
        f"Lowest updated atomic number is {z.min()} but "
        f"the atom filter is set to {remove_threshold}"
    )
    n_atoms = len(z)
    assert n_atoms == len(pos), "Number of z and pos values doesn't match after AtomFilter"
    assert n_atoms == len(batch), "Number of z and batch values doesn't match after AtomFilter"
Example #6
0
def test_scalar_invariance():
    """Check that the scalar prediction is invariant under a fixed rotation.

    An equivariant-transformer (no prior) is evaluated on 100 atoms with
    ``z = 1``, grouped into 50 two-atom molecules, at random positions. It is
    then evaluated again with the positions multiplied by a fixed rotation
    matrix. The first model outputs must agree within floating-point
    tolerance.
    """
    torch.manual_seed(1234)
    rotate = torch.tensor([
        [0.9886788, -0.1102370, 0.1017945],
        [0.1363630, 0.9431761, -0.3030248],
        [-0.0626055, 0.3134752, 0.9475304],
    ])

    model = create_model(
        load_example_args("equivariant-transformer", prior_model=None))
    z = torch.ones(100, dtype=torch.long)
    pos = torch.randn(100, 3)
    # batch = [0, 0, 1, 1, ..., 49, 49]: 50 molecules of two atoms each
    batch = torch.arange(50, dtype=torch.long).repeat_interleave(2)

    y = model(z, pos, batch)[0]
    y_rot = model(z, pos @ rotate, batch)[0]
    # torch.testing.assert_allclose is deprecated; assert_close is its successor
    # and uses the same dtype-based default tolerances.
    torch.testing.assert_close(y, y_rot)
Example #7
0
def test_gn(device, num_atoms):
    """Compare a SchNet-like graph network with its ``optimize``d counterpart.

    The test is skipped for ``device == 'cuda'`` when no GPU is available.
    Both models run on the same random inputs. Energies must match within
    ``atol=5e-7`` and gradients within ``atol=1e-5``.
    """
    if not pt.cuda.is_available() and device == 'cuda':
        pytest.skip('No GPU')

    device = pt.device(device)

    # Random inputs: integer elements in [1, 100) and positions in [-5, 5)
    elements = pt.randint(1, 100, (num_atoms,)).to(device)
    positions = (10 * pt.rand((num_atoms, 3)) - 5).to(device)

    # Create the non-optimized reference model; this configuration mirrors SchNet:
    #   TorchMD_GN(rbf_type='gauss', trainable_rbf=False, activation='ssp', neighbor_embedding=False)
    args = dict(
        embedding_dimension=128,
        num_layers=6,
        num_rbf=50,
        rbf_type='gauss',
        trainable_rbf=False,
        activation='ssp',
        neighbor_embedding=False,
        cutoff_lower=0.0,
        cutoff_upper=5.0,
        max_z=100,
        max_num_neighbors=num_atoms,
        model='graph-network',
        aggr='add',
        derivative=True,
        atom_filter=-1,
        prior_model=None,
        output_model='Scalar',
        reduce_op='add',
    )
    reference_model = create_model(args).to(device)
    expected_energy, expected_gradient = reference_model(elements, positions)

    # Optimize the reference model and run it on the very same inputs
    optimized_model = optimize(reference_model).to(device)
    actual_energy, actual_gradient = optimized_model(elements, positions)

    assert pt.allclose(expected_energy, actual_energy, atol=5e-7)
    assert pt.allclose(expected_gradient, actual_gradient, atol=1e-5)
Example #8
0
def test_gated_eq_gradients():
    """Check that backpropagating the force loss does not produce NaN gradients.

    An equivariant-transformer with ``cutoff_upper=5`` and ``derivative=True``
    is run on three atoms, one of which sits 10 units away from the others.
    The summed forces are then differentiated with respect to the
    representation model's embedding weights.
    """
    config = load_example_args(
        "equivariant-transformer", prior_model=None, cutoff_upper=5, derivative=True
    )
    model = create_model(config)

    # two atoms 1 apart near the origin; the third at x=10 lies beyond the
    # cutoff of 5 from both, so it has no neighbours
    z = torch.tensor([1, 1, 8])
    pos = torch.tensor([[0, 0, 0], [0, 1, 0], [10, 0, 0]], dtype=torch.float)

    _, forces = model(z, pos)

    # gradient of the force loss w.r.t. the model's embedding weights
    embedding_weight = model.representation_model.embedding.weight
    (deriv,) = grad(forces.sum(), embedding_weight)
    has_nan = deriv.isnan().any()
    assert not has_nan, "Encountered NaN gradients while backpropagating the force loss"
Example #9
0
def test_forward_output(model_name, output_model, overwrite_reference=False):
    """Compare a seeded model's outputs with stored reference outputs.

    A model is built after ``pl.seed_everything(1234)`` from the example args
    for ``model_name`` with the given ``output_model``. Derivatives are only
    enabled for the "Scalar" and "EquivariantScalar" output models. The model
    runs on a 5-atom example batch, and the results are compared with
    ``expected.pkl`` located next to this file.

    Args:
        model_name: first-level key in the reference file.
        output_model: second-level key in the reference file.
        overwrite_reference: if True, store the current outputs as the new
            reference and then fail deliberately. Must not be set during testing.

    Raises:
        AssertionError: if the reference file is missing, if the outputs
            differ from the reference, or always after overwriting the reference.
    """
    pl.seed_everything(1234)

    # create model and sample batch
    derivative = output_model in ["Scalar", "EquivariantScalar"]
    args = load_example_args(
        model_name, remove_prior=True, output_model=output_model, derivative=derivative,
    )
    model = create_model(args)
    z, pos, batch = create_example_batch(n_atoms=5)

    # run step
    pred, deriv = model(z, pos, batch)

    # load reference outputs
    expected_path = join(dirname(__file__), "expected.pkl")
    assert exists(expected_path), "Couldn't locate reference outputs."
    # NOTE: pickle.load can execute arbitrary code; expected.pkl must remain a
    # trusted, repository-controlled artifact.
    with open(expected_path, "rb") as f:
        expected = pickle.load(f)

    if overwrite_reference:
        # this overwrites the previous reference outputs and shouldn't be executed during testing
        expected.setdefault(model_name, {})[output_model] = dict(pred=pred, deriv=deriv)
        with open(expected_path, "wb") as f:
            pickle.dump(expected, f)
        # raise explicitly: an `assert False` is stripped under `python -O`,
        # which would let the overwrite run silently "pass"
        raise AssertionError(
            f"Set new reference outputs for {model_name} with output model {output_model}."
        )

    # compare actual output with reference; assert_close replaces the deprecated
    # assert_allclose, and check_dtype=False keeps its dtype-tolerant comparison
    reference = expected[model_name][output_model]
    torch.testing.assert_close(pred, reference["pred"], check_dtype=False)
    if derivative:
        torch.testing.assert_close(deriv, reference["deriv"], check_dtype=False)
Example #10
0
def test_forward_torchscript(model_name):
    """Check that the example model compiles with TorchScript and runs.

    The model is built with the prior removed and ``derivative=True``. It is
    compiled via ``torch.jit.script`` and called on the default example batch.
    """
    z, pos, batch = create_example_batch()
    config = load_example_args(model_name, remove_prior=True, derivative=True)
    scripted_model = torch.jit.script(create_model(config))
    scripted_model(z, pos, batch=batch)
Example #11
0
def test_forward_output_modules(model_name, output_model):
    """Smoke-test a forward pass for the given model and output module.

    The prior is removed. The only check is that the call completes without
    raising.
    """
    z, pos, batch = create_example_batch()
    model = create_model(
        load_example_args(model_name, remove_prior=True, output_model=output_model)
    )
    model(z, pos, batch=batch)