Example #1
 def test_dropout_requires_key(self):
   mod = mlp.MLP([1, 1])
   with self.assertRaisesRegex(ValueError, "rng key must be passed"):
     mod(jnp.ones([1, 1]), dropout_rate=0.5)
Example #2
 def test_no_dropout_rejects_rng(self):
   mod = mlp.MLP([1, 1])
   with self.assertRaisesRegex(ValueError, "only.*when using dropout"):
     mod(jnp.ones([1, 1]), rng=jax.random.PRNGKey(42))
Example #3
 def test_custom_name(self):
   mod = mlp.MLP([1], name="foobar")
   self.assertEqual(mod.name, "foobar")
Example #4
 def test_reverse_override_name(self):
   mod = mlp.MLP([2, 3, 4])
   mod(jnp.ones([1, 1]))
   rev = mod.reverse(name="foobar")
   self.assertEqual(rev.name, "foobar")
Example #5
 def test_passes_with_bias_to_layers(self, with_bias):
   mod = mlp.MLP([1, 1, 1], with_bias=with_bias)
   for linear in mod.layers:
     self.assertEqual(linear.with_bias, with_bias)
Example #6
 def test_default_name(self):
   mod = mlp.MLP([1])
   self.assertEqual(mod.name, "mlp")
Example #7
 def test_activate_final(self, num_layers):
   activation = CountingActivation()
   mod = mlp.MLP([1] * num_layers, activate_final=True, activation=activation)
   mod(jnp.ones([1, 1]))
   self.assertEqual(activation.count, num_layers)
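Example #7 uses a CountingActivation helper that is not defined anywhere in this collection. Judging from the assertion on activation.count, it is an identity activation that records how many times it is called; a minimal sketch, assuming exactly that behavior:

class CountingActivation:

  def __init__(self):
    self.count = 0

  def __call__(self, x):
    # Identity activation: pass x through unchanged, counting each call.
    self.count += 1
    return x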
Example #8
 def test_adds_index_to_layer_names(self, num_layers):
   mod = mlp.MLP([1] * num_layers)
   for index, linear in enumerate(mod.layers):
     self.assertEqual(linear.name, "linear_%d" % index)
Example #9
 def test_layers(self, num_layers):
   mod = mlp.MLP([1] * num_layers)
   self.assertLen(mod.layers, num_layers)
Example #10
 def test_b_init_when_with_bias_false(self):
   with self.assertRaisesRegex(ValueError, "b_init must not be set"):
     mlp.MLP([1], with_bias=False, b_init=lambda *_: _)  # any non-None b_init triggers the check
Example #11
def reversed_mlp(**kwargs):
  mod = mlp.MLP([2, 3, 4], **kwargs)
  mod(jnp.ones([1, 1]))
  return mod.reverse()
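reversed_mlp is a module-level helper rather than a test: it builds an MLP with output sizes [2, 3, 4], calls it once on a dummy input so every layer's input size becomes known, and returns the layer-wise reversed module. A hedged usage sketch; the reversed sizes [3, 2, 1] and the final shape are inferred from how MLP.reverse mirrors the layers, not asserted anywhere in the source:

rev = reversed_mlp()          # presumably an MLP with output sizes [3, 2, 1]
out = rev(jnp.ones([1, 4]))   # feed back the original module's output size
assert out.shape == (1, 1)    # inferred: the reverse ends at the original input size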
Example #12
 def test_repr(self):
   mod = mlp.MLP([1, 2, 3])
   for index, linear in enumerate(mod.layers):
     self.assertEqual(
         repr(linear),
         "Linear(output_size={}, name='linear_{}')".format(index + 1, index))
Example #13
 def test_output_size(self, output_sizes):
   mod = mlp.MLP(output_sizes)
   expected_output_size = output_sizes[-1] if output_sizes else None
   self.assertEqual(mod.output_size, expected_output_size)
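All of the examples above are fragments of one test suite: the methods reference self, take parameterized arguments (num_layers, with_bias, output_sizes), and call Haiku modules directly. A minimal sketch of the scaffolding they assume is below; the import paths follow dm-haiku's internal layout, and the decorator values are illustrative rather than copied from the source:

from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp

from haiku._src import test_utils
from haiku._src.nets import mlp


class MLPTest(parameterized.TestCase):

  # Haiku modules must be built inside a transform, so each test method is
  # presumably wrapped with a decorator along the lines of:
  #   @test_utils.transform_and_run
  # Parameterized methods additionally carry something like:
  #   @parameterized.parameters(1, 2, 3)
  ...


if __name__ == "__main__":
  absltest.main()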