Beispiel #1
0
def test_shape_elemental_gate(
    elemental_gate_layer,
    elements,
    random_int_input,
    random_shape,
):
    """Check the output shape of ``elemental_gate_layer``.

    The expected shape is ``random_shape`` with one trailing dimension of
    size ``len(elements)`` appended.
    """
    n_elements = len(elements)
    expected_shape = random_shape + [n_elements]
    assert_output_shape_valid(
        elemental_gate_layer, [random_int_input], expected_shape
    )
Beispiel #2
0
def test_shape_cfconv(
    cfconv_layer,
    random_atomic_env,
    r_ij,
    neighbors,
    neighbor_mask,
    f_ij,
    cfconv_output_shape,
):
    """Check that ``cfconv_layer`` produces ``cfconv_output_shape``.

    The layer receives five positional inputs, in this order:
    ``random_atomic_env``, ``r_ij``, ``neighbors``, ``neighbor_mask``, ``f_ij``.
    """
    layer_inputs = [
        random_atomic_env,
        r_ij,
        neighbors,
        neighbor_mask,
        f_ij,
    ]
    assert_output_shape_valid(
        cfconv_layer, layer_inputs, cfconv_output_shape
    )
Beispiel #3
0
def test_shape_schnetinteraction(
    schnet_interaction,
    random_atomic_env,
    r_ij,
    neighbors,
    neighbor_mask,
    f_ij,
    interaction_output_shape,
):
    """Check that ``schnet_interaction`` produces ``interaction_output_shape``.

    The block receives the same five positional inputs as the cfconv shape
    test: ``random_atomic_env``, ``r_ij``, ``neighbors``, ``neighbor_mask``,
    ``f_ij``.
    """
    block_inputs = [
        random_atomic_env,
        r_ij,
        neighbors,
        neighbor_mask,
        f_ij,
    ]
    assert_output_shape_valid(
        schnet_interaction, block_inputs, interaction_output_shape
    )
Beispiel #4
0
def test_shape_scale_shift(random_float_input, random_shape):
    """Check the output shape of ``spk.nn.ScaleShift``.

    The layer is built from single-element random ``mean`` and ``std``
    tensors. The expected output shape is ``random_shape`` (presumably the
    shape of ``random_float_input``, i.e. shape-preserving; confirm against
    the fixtures).
    """
    # Tuple assignment evaluates left to right: mean is drawn before std.
    mean, std = torch.rand(1), torch.rand(1)
    scale_shift = spk.nn.ScaleShift(mean, std)

    assert_output_shape_valid(scale_shift, [random_float_input], random_shape)
Beispiel #5
0
def test_shape_dense(dense_layer, random_float_input, random_shape, random_output_dim):
    """Check that ``dense_layer`` replaces the last dimension of ``random_shape``.

    All leading dimensions are kept and the final one becomes
    ``random_output_dim``.
    """
    leading_dims = random_shape[:-1]
    expected_shape = leading_dims + [random_output_dim]
    assert_output_shape_valid(dense_layer, [random_float_input], expected_shape)
Beispiel #6
0
def test_gaussian_smearing(
    gaussion_smearing_layer, random_interatomic_distances, gaussian_smearing_shape
):
    """Check that the smearing layer produces ``gaussian_smearing_shape``.

    NOTE(review): ``gaussion_smearing_layer`` is misspelled. It is presumably
    an injected fixture defined elsewhere (e.g. a conftest). Renaming the
    parameter here without renaming the fixture would break injection.
    """
    layer_inputs = [random_interatomic_distances]
    assert_output_shape_valid(
        gaussion_smearing_layer, layer_inputs, gaussian_smearing_shape
    )
Beispiel #7
0
def test_shape_schnet(schnet, schnet_batch, schnet_output_shape):
    """Check that ``schnet`` maps ``schnet_batch`` to ``schnet_output_shape``."""
    model_inputs = [schnet_batch]
    assert_output_shape_valid(schnet, model_inputs, schnet_output_shape)
Beispiel #8
0
def x_test_shape_neighbor_elements(atomic_numbers, neighbors):
    """Shape check for ``spk.nn.NeighborElements`` (currently disabled).

    The ``x_`` prefix means the name does not match pytest's default
    ``test*`` function pattern, so the test is not collected.

    The expected output shape equals ``neighbors.shape``. ``atomic_numbers``
    gets a trailing singleton dimension via ``unsqueeze(-1)`` before being
    passed in.
    """
    # TODO: either update the NeighborElements docstring or squeeze() the
    # extra dimension. Presumably this mismatch is why the test is disabled;
    # confirm before re-enabling.
    neighbor_elements = spk.nn.NeighborElements()
    layer_inputs = [atomic_numbers.unsqueeze(-1), neighbors]
    expected_shape = list(neighbors.shape)
    assert_output_shape_valid(neighbor_elements, layer_inputs, expected_shape)
Beispiel #9
0
def test_shape_cutoff(cutoff_layer, random_interatomic_distances):
    """Check that ``cutoff_layer`` returns a tensor shaped like its input."""
    distances = random_interatomic_distances
    expected_shape = list(distances.shape)
    assert_output_shape_valid(cutoff_layer, [distances], expected_shape)
Beispiel #10
0
def test_shape_tiled_multilayer_network(
    tiled_mlp_layer, n_mlp_tiles, random_float_input, random_shape, random_output_dim
):
    """Check the output shape of ``tiled_mlp_layer``.

    The leading dimensions of ``random_shape`` are kept. The last dimension
    becomes ``random_output_dim * n_mlp_tiles``, i.e. one
    ``random_output_dim``-wide slice per tile.
    """
    tiled_width = random_output_dim * n_mlp_tiles
    expected_shape = random_shape[:-1] + [tiled_width]
    assert_output_shape_valid(tiled_mlp_layer, [random_float_input], expected_shape)
Beispiel #11
0
def test_shape_aggregate():
    """Check that ``spk.nn.Aggregate(axis=1)`` removes dimension 1.

    A random ``(3, 4, 5)`` tensor is expected to come out as ``(3, 5)``.
    """
    aggregate = spk.nn.Aggregate(axis=1)
    data = torch.rand((3, 4, 5))
    # Axis 1 (size 4) is reduced away.
    expected_shape = [3, 5]
    assert_output_shape_valid(aggregate, [data], expected_shape)
Beispiel #12
0
def test_shape_standardize(random_float_input, random_shape):
    """Check the output shape of ``spk.nn.Standardize``.

    The layer is built from single-element random ``mean`` and ``std``
    tensors. The expected output shape is ``random_shape`` (presumably the
    shape of ``random_float_input``; confirm against the fixtures). Only the
    shape is checked, not the values.
    """
    # Tuple assignment evaluates left to right: mean is drawn before std.
    mean, std = torch.rand(1), torch.rand(1)
    standardize = spk.nn.Standardize(mean, std)

    assert_output_shape_valid(standardize, [random_float_input], random_shape)