Beispiel #1
0
def test_shape_schnet_with_cutoff(schnet_batch, batchsize, n_atoms, n_atom_basis):
    """Check SchNet output shapes for the cosine and mollifier cutoff networks.

    Both models are built first (cosine, then mollifier). Each is then passed to
    ``assert_equal_shape`` with the expected shape
    ``[batchsize, n_atoms, n_atom_basis]``.
    """
    batch_list = [schnet_batch]
    models = [
        SchNet(n_atom_basis=n_atom_basis, cutoff_network=cutoff_cls)
        for cutoff_cls in (CosineCutoff, MollifierCutoff)
    ]

    for model in models:
        # A fresh expected-shape list per call, as in the original.
        assert_equal_shape(model, batch_list, [batchsize, n_atoms, n_atom_basis])
Beispiel #2
0
def test_parameter_update_schnet_with_cutoff(schnet_batch, n_atom_basis):
    """Check parameter updates for SchNet with cosine and mollifier cutoffs.

    ``assert_params_changed`` is applied to both models with the same input
    batch. Parameters whose names fall under the distance expansion or the
    first interaction block's cutoff networks are excluded.
    """
    cosine_model, mollifier_model = [
        SchNet(n_atom_basis, cutoff_network=cutoff_cls)
        for cutoff_cls in (CosineCutoff, MollifierCutoff)
    ]
    batch_list = [schnet_batch]
    frozen_prefixes = [
        "distance_expansion",
        "interactions.0.cutoff_network",
        "interactions.0.cfconv.cutoff_network",
    ]

    for model in (cosine_model, mollifier_model):
        assert_params_changed(model, batch_list, exclude=frozen_prefixes)
Beispiel #3
0
def test_gaussian_smearing_is_trainable(schnet_batch):
    """Check parameter updates for SchNet built with ``trainable_gaussians=True``.

    Only the first interaction block's cutoff networks are excluded.
    ``distance_expansion`` is intentionally absent from the exclusion list, so
    it is covered by ``assert_params_changed`` as well.
    """
    model = SchNet(trainable_gaussians=True)
    skipped = [
        "interactions.0.cutoff_network",
        "interactions.0.cfconv.cutoff_network",
    ]
    assert_params_changed(model, [schnet_batch], exclude=skipped)
Beispiel #4
0
def test_parameter_update_schnet(schnet_batch):
    """Check parameter updates for a default-configured SchNet.

    The distance expansion and the first interaction block's cutoff
    networks are excluded from ``assert_params_changed``.
    """
    skipped = [
        "distance_expansion",
        "interactions.0.cutoff_network",
        "interactions.0.cfconv.cutoff_network",
    ]
    model = SchNet()
    assert_params_changed(model, [schnet_batch], exclude=skipped)
Beispiel #5
0
def test_parameter_update_schnet(schnet_batch, n_interactions):
    """Check parameter updates for SchNet with ``n_interactions`` blocks.

    ``assert_params_changed`` is applied to the model with ``schnet_batch``
    wrapped in a list. The distance expansion is excluded, and so are the
    cutoff networks (both the block-level and the cfconv-level name) of
    every interaction block.

    Args:
        schnet_batch: input batch fixture.
        n_interactions (int): number of interaction blocks of the model.
    """
    model = SchNet(n_interactions=n_interactions)
    schnet_batch = [schnet_batch]

    exclude = ["distance_expansion"]
    # ``n_interactions`` is parametrized, so the exclusions must cover every
    # interaction block rather than only block 0.
    for idx in range(n_interactions):
        exclude.append("interactions.{}.cutoff_network".format(idx))
        exclude.append("interactions.{}.cfconv.cutoff_network".format(idx))

    assert_params_changed(model, schnet_batch, exclude=exclude)
Beispiel #6
0
def build_schnet(
    return_intermediate,
    n_atom_basis,
    n_filters,
    n_interactions,
    cutoff,
    n_gaussians,
    normalize_filter,
    coupled_interactions,
    max_z,
):
    """
    Construct a SchNet model from the given hyperparameters.

    The cutoff network is obtained by calling ``get_cutoff()`` with no
    arguments. ``charged_systems`` is always set to ``False``.

    Args:
        return_intermediate (bool): if true, also return intermediate feature
            representations after each interaction block
        n_atom_basis (int): number of features used to describe atomic
            environments
        n_filters (int): number of filters used in continuous-filter convolution
        n_interactions (int): number of interaction blocks
        cutoff (float): cutoff radius of filters
        n_gaussians (int): number of Gaussians which are used to expand atom
            distances
        normalize_filter (bool): if true, divide filter by number of neighbors
            over which convolution is applied
        coupled_interactions (bool): if true, share the weights across
            interaction blocks and filter-generating networks.
        max_z (int): maximum allowed nuclear charge in dataset. This determines
            the size of the embedding matrix.

    Returns:
        SchNet: the configured model instance.
    """
    schnet_kwargs = {
        "n_atom_basis": n_atom_basis,
        "n_filters": n_filters,
        "n_interactions": n_interactions,
        "cutoff": cutoff,
        "n_gaussians": n_gaussians,
        "normalize_filter": normalize_filter,
        "coupled_interactions": coupled_interactions,
        "return_intermediate": return_intermediate,
        "max_z": max_z,
    }
    # Resolve the cutoff network before the model is constructed.
    schnet_kwargs["cutoff_network"] = get_cutoff()
    schnet_kwargs["charged_systems"] = False
    return SchNet(**schnet_kwargs)
Beispiel #7
0
def test_shape_schnet(schnet_batch, batchsize, n_atoms, n_atom_basis):
    """Check SchNet's output shape via ``assert_equal_shape``.

    The expected shape is ``[batchsize, n_atoms, n_atom_basis]``.
    """
    batch_list = [schnet_batch]
    model = SchNet(n_atom_basis=n_atom_basis)
    expected_shape = [batchsize, n_atoms, n_atom_basis]

    assert_equal_shape(model, batch_list, expected_shape)