def test_dropout_requires_key(self):
  mod = mlp.MLP([1, 1])
  with self.assertRaisesRegex(ValueError, "rng key must be passed"):
    mod(jnp.ones([1, 1]), dropout_rate=0.5)

def test_no_dropout_rejects_rng(self):
  mod = mlp.MLP([1, 1])
  with self.assertRaisesRegex(ValueError, "only.*when using dropout"):
    mod(jnp.ones([1, 1]), rng=jax.random.PRNGKey(42))
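# Illustrative sketch (not one of the original tests) of the valid call the two
# dropout tests above imply: when dropout_rate is set, an rng key must also be
# passed. The test name and the assumption that this call succeeds are mine.
def test_dropout_with_key_sketch(self):
  mod = mlp.MLP([1, 1])
  mod(jnp.ones([1, 1]), dropout_rate=0.5, rng=jax.random.PRNGKey(42))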
def test_custom_name(self):
  mod = mlp.MLP([1], name="foobar")
  self.assertEqual(mod.name, "foobar")

def test_reverse_override_name(self):
  mod = mlp.MLP([2, 3, 4])
  mod(jnp.ones([1, 1]))
  rev = mod.reverse(name="foobar")
  self.assertEqual(rev.name, "foobar")
def test_passes_with_bias_to_layers(self, with_bias):
  mod = mlp.MLP([1, 1, 1], with_bias=with_bias)
  for linear in mod.layers:
    self.assertEqual(linear.with_bias, with_bias)

def test_default_name(self):
  mod = mlp.MLP([1])
  self.assertEqual(mod.name, "mlp")

def test_activate_final(self, num_layers):
  activation = CountingActivation()
  mod = mlp.MLP([1] * num_layers, activate_final=True, activation=activation)
  mod(jnp.ones([1, 1]))
  self.assertEqual(activation.count, num_layers)
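# Minimal sketch of the CountingActivation helper assumed by test_activate_final:
# an identity activation that records how many times it is applied, so the test
# can assert that activate_final=True activates every layer's output.
class CountingActivation:

  def __init__(self):
    self.count = 0

  def __call__(self, x):
    self.count += 1
    return x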
def test_adds_index_to_layer_names(self, num_layers):
  mod = mlp.MLP([1] * num_layers)
  for index, linear in enumerate(mod.layers):
    self.assertEqual(linear.name, "linear_%d" % index)

def test_layers(self, num_layers):
  mod = mlp.MLP([1] * num_layers)
  self.assertLen(mod.layers, num_layers)

def test_b_init_when_with_bias_false(self):
  with self.assertRaisesRegex(ValueError, "b_init must not be set"):
    mlp.MLP([1], with_bias=False, b_init=lambda *_: _)
def reversed_mlp(**kwargs):
  mod = mlp.MLP([2, 3, 4], **kwargs)
  mod(jnp.ones([1, 1]))
  return mod.reverse()

def test_repr(self):
  mod = mlp.MLP([1, 2, 3])
  for index, linear in enumerate(mod.layers):
    self.assertEqual(
        repr(linear),
        "Linear(output_size={}, name='linear_{}')".format(index + 1, index))

def test_output_size(self, output_sizes):
  mod = mlp.MLP(output_sizes)
  expected_output_size = output_sizes[-1] if output_sizes else None
  self.assertEqual(mod.output_size, expected_output_size)
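# Illustrative sketch of how the reversed_mlp helper above would be used in a
# test. The expected sizes are an assumption: reversing an MLP([2, 3, 4]) that
# was applied to a [1, 1] input should yield layers that map back toward the
# original input size, i.e. output sizes [3, 2, 1].
def test_reversed_output_sizes_sketch(self):
  mod = reversed_mlp()
  self.assertEqual([layer.output_size for layer in mod.layers], [3, 2, 1])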
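# The tests that take extra arguments (num_layers, output_sizes, with_bias) are
# presumably driven by absl.testing.parameterized decorators, which are not
# shown in this excerpt. A sketch with assumed, illustrative parameter values:
#
#   from absl.testing import parameterized
#
#   @parameterized.parameters(1, 2, 3)
#   def test_layers(self, num_layers):
#     mod = mlp.MLP([1] * num_layers)
#     self.assertLen(mod.layers, num_layers)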