def test_shape_elemental_gate(
    elemental_gate_layer,
    elements,
    random_int_input,
    random_shape,
):
    """Elemental gate output should append an axis of size ``len(elements)``.

    The expected shape is ``random_shape`` extended by one trailing dimension
    whose size equals the number of entries in ``elements``.
    """
    n_elements = len(elements)
    assert_output_shape_valid(
        elemental_gate_layer,
        [random_int_input],
        random_shape + [n_elements],
    )
def test_shape_cfconv(
    cfconv_layer,
    random_atomic_env,
    r_ij,
    neighbors,
    neighbor_mask,
    f_ij,
    cfconv_output_shape,
):
    """The cfconv layer should produce ``cfconv_output_shape``.

    The layer is called with five positional inputs, in this order:
    atomic environment, ``r_ij``, ``neighbors``, ``neighbor_mask``, ``f_ij``.
    """
    layer_args = [
        random_atomic_env,
        r_ij,
        neighbors,
        neighbor_mask,
        f_ij,
    ]
    assert_output_shape_valid(cfconv_layer, layer_args, cfconv_output_shape)
def test_shape_schnetinteraction(
    schnet_interaction,
    random_atomic_env,
    r_ij,
    neighbors,
    neighbor_mask,
    f_ij,
    interaction_output_shape,
):
    """The SchNet interaction block should produce ``interaction_output_shape``.

    It receives the same five positional inputs as the cfconv shape test:
    atomic environment, ``r_ij``, ``neighbors``, ``neighbor_mask``, ``f_ij``.
    """
    interaction_args = [
        random_atomic_env,
        r_ij,
        neighbors,
        neighbor_mask,
        f_ij,
    ]
    assert_output_shape_valid(
        schnet_interaction, interaction_args, interaction_output_shape
    )
def test_shape_scale_shift(random_float_input, random_shape):
    """ScaleShift with single-element mean/std tensors must keep the input shape."""
    # Arguments are evaluated left to right, so the RNG draws mean first and
    # std second.
    scale_shift = spk.nn.ScaleShift(torch.rand(1), torch.rand(1))
    assert_output_shape_valid(scale_shift, [random_float_input], random_shape)
def test_shape_dense(dense_layer, random_float_input, random_shape, random_output_dim):
    """Dense layer should replace the last axis with ``random_output_dim``."""
    leading_dims = random_shape[:-1]
    expected_shape = leading_dims + [random_output_dim]
    assert_output_shape_valid(dense_layer, [random_float_input], expected_shape)
def test_gaussian_smearing(
    gaussion_smearing_layer, random_interatomic_distances, gaussian_smearing_shape
):
    """Gaussian smearing of interatomic distances should yield ``gaussian_smearing_shape``."""
    # NOTE(review): the parameter is spelled "gaussion"; pytest resolves
    # fixtures by parameter name, so this presumably matches an identically
    # spelled fixture elsewhere. Rename both together if fixing the typo.
    smearing_inputs = [random_interatomic_distances]
    assert_output_shape_valid(
        gaussion_smearing_layer, smearing_inputs, gaussian_smearing_shape
    )
def test_shape_schnet(schnet, schnet_batch, schnet_output_shape):
    """The ``schnet`` model applied to one batch should produce ``schnet_output_shape``."""
    model_inputs = [schnet_batch]
    assert_output_shape_valid(schnet, model_inputs, schnet_output_shape)
def x_test_shape_neighbor_elements(atomic_numbers, neighbors):
    """Shape check for ``spk.nn.NeighborElements`` (currently disabled).

    The ``x_`` prefix keeps pytest's default ``test*`` function pattern from
    collecting this test. The expected output shape equals ``neighbors.shape``.
    """
    # TODO: change the NeighborElements docstring or squeeze() -- the atomic
    # numbers are passed with an extra trailing axis via unsqueeze(-1) below.
    neighbor_elements = spk.nn.NeighborElements()
    expected_shape = list(neighbors.shape)
    layer_inputs = [atomic_numbers.unsqueeze(-1), neighbors]
    assert_output_shape_valid(neighbor_elements, layer_inputs, expected_shape)
def test_shape_cutoff(cutoff_layer, random_interatomic_distances):
    """Cutoff layer output should have exactly the shape of its distance input."""
    distances = random_interatomic_distances
    assert_output_shape_valid(cutoff_layer, [distances], list(distances.shape))
def test_shape_tiled_multilayer_network(
    tiled_mlp_layer, n_mlp_tiles, random_float_input, random_shape, random_output_dim
):
    """Tiled MLP should widen the last axis to ``random_output_dim * n_mlp_tiles``.

    All leading axes of ``random_shape`` are kept unchanged.
    """
    tiled_width = random_output_dim * n_mlp_tiles
    assert_output_shape_valid(
        tiled_mlp_layer,
        [random_float_input],
        random_shape[:-1] + [tiled_width],
    )
def test_shape_aggregate():
    """Aggregating over axis 1 should reduce a (3, 4, 5) tensor to (3, 5)."""
    aggregate = spk.nn.Aggregate(axis=1)
    x = torch.rand((3, 4, 5))
    assert_output_shape_valid(aggregate, [x], [3, 5])
def test_shape_standardize(random_float_input, random_shape):
    """Standardize with single-element mean/std tensors must keep the input shape."""
    # Arguments are evaluated left to right, so the RNG draws mean first and
    # std second.
    standardize = spk.nn.Standardize(torch.rand(1), torch.rand(1))
    assert_output_shape_valid(standardize, [random_float_input], random_shape)