def test_multiplier(self):
    """With an all-ones bias and input, ``multiplier=-1`` should cancel to zero."""
    ones = jnp.ones([1, 1])
    module = bias.Bias(b_init=jnp.ones)
    out = module(ones, multiplier=-1)
    # Single-element output, so summing is equivalent to checking that element.
    np.testing.assert_allclose(jnp.sum(out), 0)
def test_b_init_custom(self):
    """A custom ``b_init`` (here ``jnp.ones``) is used to initialise the bias."""
    inputs = jnp.ones([1, 1])
    layer = bias.Bias(b_init=jnp.ones)
    expected = inputs + 1  # all-ones bias added to the input
    np.testing.assert_allclose(layer(inputs), expected)
def test_name(self):
    """The ``name`` given to the constructor is exposed as ``.name``."""
    self.assertEqual(bias.Bias(name="foo").name, "foo")
def test_bias_dims_invalid(self):
    """A bias dim beyond the input's rank raises ``ValueError``."""
    module = bias.Bias(bias_dims=[1, 5])
    rank3_input = jnp.ones([1, 2, 3])
    expected = "5 .* out of range for input of rank 3"
    with self.assertRaisesRegex(ValueError, expected):
        module(rank3_input)
def test_b_init_defaults_to_zeros(self):
    """Without ``b_init`` the output is expected to equal the input (zero bias)."""
    x = jnp.ones([1, 1])
    np.testing.assert_allclose(bias.Bias()(x), x)
def f():
    # `self` is captured from the enclosing test method (not visible here).
    # For a (1, 2, 3) input, dims [-1, -2] select the trailing two axes; the
    # expected shape lists them in ascending order despite the reversed spec.
    layer = bias.Bias(bias_dims=[-1, -2])
    layer(jnp.ones([1, 2, 3]))
    self.assertEqual(layer.bias_shape, (2, 3))
def f():
    # `b`, `d1`, `d2`, `d3` and `self` come from the enclosing scope (not visible).
    layer = bias.Bias(bias_dims=[1, 3])
    result = layer(jnp.ones([b, d1, d2, d3]))
    # Axis 2 is not listed in bias_dims, so it is expected as size 1.
    self.assertEqual(layer.bias_shape, (d1, 1, d3))
    return result
def f():
    # Apply a Bias with no bias dims to a rank-4 input and return its output.
    return bias.Bias(bias_dims=())(jnp.ones([1, 2, 3, 4]))
def test_output_size_valid(self):
    """An input shaped ``(-1,) + output_size`` is accepted.

    Besides not raising, the bias add should leave the input shape unchanged.
    """
    mod = bias.Bias(output_size=(2 * 2,))
    out = mod(jnp.ones([2, 2 * 2]))
    # Previously this test asserted nothing; pin the broadcast result shape.
    self.assertEqual(out.shape, (2, 2 * 2))
def test_output_shape(self):
    """An input not shaped ``(-1, 4)`` for ``output_size=(4,)`` raises.

    NOTE(review): despite its name, this exercises the shape-mismatch error.
    """
    mod_under_test = bias.Bias(output_size=(2 * 2,))
    pattern = "Input shape must be [(]-1, 4[)]"
    with self.assertRaisesRegex(ValueError, pattern):
        mod_under_test(jnp.ones([2, 2, 2]))