Beispiel #1
0
def test_batch_normalized_mlp_allocation():
    """Allocating a BatchNormalizedMLP configures its children correctly.

    Built with dims ``[5, 7, 9]``, after ``allocate()`` the first child of
    activation 0 must report ``input_dim == 7`` and that of activation 1
    ``input_dim == 9`` (the 2nd and 3rd entries of the dims list). None of the
    linear transformations may use a bias.
    """
    mlp = BatchNormalizedMLP([Tanh(), Tanh()], [5, 7, 9])
    mlp.allocate()

    # Checked in index order so a failure reports the first mismatch.
    expected_input_dims = {0: 7, 1: 9}
    for index, dim in expected_input_dims.items():
        assert mlp.activations[index].children[0].input_dim == dim

    assert all(not linear.use_bias for linear in mlp.linear_transformations)
Beispiel #2
0
def test_batch_normalized_mlp_learn_scale_propagated_at_alloc():
    """``learn_scale=False`` reaches the activations' children on allocate.

    The MLP itself records the flag immediately. Each activation's first
    child still reports a truthy ``learn_scale`` until ``allocate()`` is
    called, after which none of them do.
    """
    mlp = BatchNormalizedMLP([Tanh(), Tanh()], [5, 7, 9], learn_scale=False)
    assert not mlp.learn_scale

    def child_scale_flags():
        # Fresh lazy generator each call, so all()/any() can short-circuit.
        return (act.children[0].learn_scale for act in mlp.activations)

    assert all(child_scale_flags())
    mlp.allocate()
    assert not any(child_scale_flags())
Beispiel #3
0
def test_batch_normalized_mlp_mean_only_propagated_at_alloc():
    """``mean_only=True`` reaches the activations' children on allocate.

    The MLP itself records the flag immediately. Each activation's first
    child reports a falsy ``mean_only`` until ``allocate()`` is called,
    after which all of them report it as truthy.
    """
    mlp = BatchNormalizedMLP([Tanh(), Tanh()], [5, 7, 9], mean_only=True)
    assert mlp.mean_only

    def child_mean_only_flags():
        # Fresh lazy generator each call, so any()/all() can short-circuit.
        return (act.children[0].mean_only for act in mlp.activations)

    assert not any(child_mean_only_flags())
    mlp.allocate()
    assert all(child_mean_only_flags())
Beispiel #4
0
def test_batch_normalized_mlp_learn_scale_propagated_at_alloc():
    """Verify ``learn_scale=False`` is pushed to the children by allocate().

    Before allocation only the MLP reflects the setting, and every
    activation's first child still has a truthy ``learn_scale``. After
    allocation every such child must have it falsy.
    """
    mlp = BatchNormalizedMLP(
        [Tanh(), Tanh()],
        [5, 7, 9],
        learn_scale=False,
    )
    assert not mlp.learn_scale
    assert all(a.children[0].learn_scale for a in mlp.activations)

    mlp.allocate()
    assert all(not a.children[0].learn_scale for a in mlp.activations)
Beispiel #5
0
def test_batch_normalized_mlp_mean_only_propagated_at_alloc():
    """Verify ``mean_only=True`` is pushed to the children by allocate().

    Before allocation only the MLP reflects the setting, and no activation's
    first child has ``mean_only`` set. After allocation every such child
    must have it set.
    """
    mlp = BatchNormalizedMLP(
        [Tanh(), Tanh()],
        [5, 7, 9],
        mean_only=True,
    )
    assert mlp.mean_only
    assert all(not a.children[0].mean_only for a in mlp.activations)

    mlp.allocate()
    assert all(a.children[0].mean_only for a in mlp.activations)