def test_batch_normalized_mlp_allocation():
    """Check allocation wiring of a BatchNormalizedMLP.

    After allocate(), the batch-norm child of each activation must have an
    input dimension matching the corresponding layer's output size, and
    every linear transformation must have its bias disabled.
    """
    brick = BatchNormalizedMLP([Tanh(), Tanh()], [5, 7, 9])
    brick.allocate()
    # Layer output sizes are the tail of the dims list [5, 7, 9].
    for activation, expected_dim in zip(brick.activations, (7, 9)):
        assert activation.children[0].input_dim == expected_dim
    assert all(not layer.use_bias for layer in brick.linear_transformations)
def test_batch_normalized_mlp_learn_scale_propagated_at_alloc():
    """Check learn_scale=False reaches the children only at allocation.

    Before allocate() each batch-norm child still has learn_scale enabled;
    allocate() pushes the brick-level flag down to all of them.
    """
    brick = BatchNormalizedMLP([Tanh(), Tanh()], [5, 7, 9], learn_scale=False)
    assert not brick.learn_scale
    # Children are untouched until allocation happens.
    for activation in brick.activations:
        assert activation.children[0].learn_scale
    brick.allocate()
    # Allocation propagated the disabled flag to every child.
    for activation in brick.activations:
        assert not activation.children[0].learn_scale
def test_batch_normalized_mlp_mean_only_propagated_at_alloc():
    """Check mean_only=True reaches the children only at allocation.

    Before allocate() each batch-norm child still has mean_only disabled;
    allocate() pushes the brick-level flag down to all of them.
    """
    brick = BatchNormalizedMLP([Tanh(), Tanh()], [5, 7, 9], mean_only=True)
    assert brick.mean_only
    # Children are untouched until allocation happens.
    for activation in brick.activations:
        assert not activation.children[0].mean_only
    brick.allocate()
    # Allocation propagated the enabled flag to every child.
    for activation in brick.activations:
        assert activation.children[0].mean_only
# NOTE(review): exact duplicate of test_batch_normalized_mlp_learn_scale_propagated_at_alloc
# defined earlier in this file. This later definition shadows the earlier one at import
# time, so only this copy is collected and run by pytest — one of the two should be
# removed (or renamed if they were meant to diverge).
def test_batch_normalized_mlp_learn_scale_propagated_at_alloc():
    """Test that setting learn_scale on a BatchNormalizedMLP works."""
    mlp = BatchNormalizedMLP([Tanh(), Tanh()], [5, 7, 9], learn_scale=False)
    # Brick-level flag is set immediately at construction.
    assert not mlp.learn_scale
    # Children still carry the default (enabled) flag before allocation.
    assert all(act.children[0].learn_scale for act in mlp.activations)
    mlp.allocate()
    # allocate() propagated learn_scale=False down to every batch-norm child.
    assert not any(act.children[0].learn_scale for act in mlp.activations)
# NOTE(review): exact duplicate of test_batch_normalized_mlp_mean_only_propagated_at_alloc
# defined earlier in this file. This later definition shadows the earlier one at import
# time, so only this copy is collected and run by pytest — one of the two should be
# removed (or renamed if they were meant to diverge).
def test_batch_normalized_mlp_mean_only_propagated_at_alloc():
    """Test that setting mean_only on a BatchNormalizedMLP works."""
    mlp = BatchNormalizedMLP([Tanh(), Tanh()], [5, 7, 9], mean_only=True)
    # Brick-level flag is set immediately at construction.
    assert mlp.mean_only
    # Children still carry the default (disabled) flag before allocation.
    assert not any(act.children[0].mean_only for act in mlp.activations)
    mlp.allocate()
    # allocate() propagated mean_only=True down to every batch-norm child.
    assert all(act.children[0].mean_only for act in mlp.activations)