示例#1
0
 def test_multiplier(self):
   """A ones bias scaled by ``multiplier=-1`` cancels a ones input."""
   inputs = jnp.ones([1, 1])
   module = bias.Bias(b_init=jnp.ones)
   out = module(inputs, multiplier=-1)
   total = jnp.sum(out)
   np.testing.assert_allclose(total, 0)
示例#2
0
 def test_b_init_custom(self):
   """A custom ``b_init`` of ones makes the output equal to input + 1."""
   x = jnp.ones([1, 1])
   expected = x + 1
   y = bias.Bias(b_init=jnp.ones)(x)
   np.testing.assert_allclose(y, expected)
示例#3
0
 def test_name(self):
   """The ``name`` constructor argument is reported back by ``.name``."""
   self.assertEqual(bias.Bias(name="foo").name, "foo")
示例#4
0
 def test_bias_dims_invalid(self):
   """A bias dim (5) beyond a rank-3 input's range raises ValueError."""
   mod = bias.Bias(bias_dims=[1, 5])
   inputs = jnp.ones([1, 2, 3])
   with self.assertRaisesRegex(
       ValueError, "5 .* out of range for input of rank 3"):
     mod(inputs)
示例#5
0
 def test_b_init_defaults_to_zeros(self):
   """With no ``b_init`` given, applying the bias leaves the input unchanged."""
   inputs = jnp.ones([1, 1])
   outputs = bias.Bias()(inputs)
   np.testing.assert_allclose(outputs, inputs)
示例#6
0
 def f():
   # Bias over the last two dims, listed as negative indices in reverse
   # order; the inferred bias_shape should still be (2, 3).
   # `self` comes from the enclosing test method (not visible here).
   inputs = jnp.ones([1, 2, 3])
   module = bias.Bias(bias_dims=[-1, -2])
   module(inputs)
   self.assertEqual(module.bias_shape, (2, 3))
示例#7
0
 def f():
   # Bias only dims 1 and 3 of a rank-4 input and return the output.
   # b, d1, d2, d3 and `self` come from the enclosing scope (not visible).
   shape = [b, d1, d2, d3]
   module = bias.Bias(bias_dims=[1, 3])
   result = module(jnp.ones(shape))
   # Dim 2 is not in bias_dims, so it appears as size 1 in bias_shape.
   self.assertEqual(module.bias_shape, (d1, 1, d3))
   return result
示例#8
0
 def f():
   # Apply a Bias with an empty bias_dims tuple to a rank-4 input.
   inputs = jnp.ones([1, 2, 3, 4])
   return bias.Bias(bias_dims=())(inputs)
示例#9
0
 def test_output_size_valid(self):
   """An input whose trailing size matches ``output_size`` is accepted.

   Besides not raising, the result must keep the input's shape, so the test
   also fails on a wrong-shaped output instead of asserting nothing.
   """
   mod = bias.Bias(output_size=(2 * 2,))
   out = mod(jnp.ones([2, 2 * 2]))
   self.assertEqual(out.shape, (2, 2 * 2))
示例#10
0
 def test_output_shape(self):
   """A Bias with ``output_size=(4,)`` rejects a [2, 2, 2] input."""
   rank3_inputs = jnp.ones([2, 2, 2])
   mod = bias.Bias(output_size=(4,))
   with self.assertRaisesRegex(ValueError, "Input shape must be [(]-1, 4[)]"):
     mod(rank3_inputs)