def test_new_rngs_deterministic(self):
    """Two fresh layers hand out identical key pairs from new_rngs(2).

    The layers differ only in n_in/n_out; the test asserts that this does
    not change the keys they produce.
    """
    plain = base.Layer()
    two_way = base.Layer(n_in=2, n_out=2)
    first_a, second_a = plain.new_rngs(2)
    first_b, second_b = two_way.new_rngs(2)
    for lhs, rhs in ((first_a, first_b), (second_a, second_b)):
        self.assertEqual(lhs.tolist(), rhs.tolist())
Example #2
0
  def test_custom_name(self):
    """str(layer) shows the default name, or the name passed as name=."""
    default_text = str(base.Layer())
    self.assertIn('Layer', default_text)
    self.assertNotIn('CustomLayer', default_text)

    named_text = str(base.Layer(name='CustomLayer'))
    self.assertIn('CustomLayer', named_text)
Example #3
0
 def test_new_rng_deterministic(self):
     """Fresh layers, once initialized, yield the same first key from new_rng().

     layer2 declares n_in=2, so it is initialized with a pair of input
     signatures that matches its declared arity, rather than a single
     signature.
     """
     input_signature = ShapeDtype((2, 3, 5))
     paired_signature = (ShapeDtype((2, 3, 5)), ShapeDtype((2, 3, 5)))
     layer1 = base.Layer()
     layer2 = base.Layer(n_in=2, n_out=2)
     _, _ = layer1.init(input_signature)
     _, _ = layer2.init(paired_signature)
     rng1 = layer1.new_rng()
     rng2 = layer2.new_rng()
     self.assertEqual(rng1.tolist(), rng2.tolist())
Example #4
0
 def test_new_rngs_deterministic(self):
     """After init, layers of different arity produce identical key pairs."""
     single_input = ShapeDtype((2, 3, 5))
     paired_inputs = (ShapeDtype((2, 3, 5)), ShapeDtype((2, 3, 5)))
     one_in = base.Layer()
     two_in = base.Layer(n_in=2, n_out=2)
     _, _ = one_in.init(single_input)
     _, _ = two_in.init(paired_inputs)
     key_a1, key_a2 = one_in.new_rngs(2)
     key_b1, key_b2 = two_in.new_rngs(2)
     for lhs, rhs in ((key_a1, key_b1), (key_a2, key_b2)):
       self.assertEqual(lhs.tolist(), rhs.tolist())
Example #5
0
 def test_forward_with_state_raises_error(self):
   """forward_with_state on a bare base.Layer raises NotImplementedError."""
   bare_layer = base.Layer()
   batch = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]])
   with self.assertRaises(NotImplementedError):
     _, _ = bare_layer.forward_with_state(
         batch, base.EMPTY_WEIGHTS, base.EMPTY_STATE, None)
 def test_new_rng_new_value_each_call(self):
     """Three successive new_rng() calls on one layer give pairwise-distinct keys.

     Every pair is compared, including the first draw against the third,
     which adjacent-only comparisons would miss.
     """
     layer = base.Layer()
     rng1 = layer.new_rng()
     rng2 = layer.new_rng()
     rng3 = layer.new_rng()
     self.assertNotEqual(rng1.tolist(), rng2.tolist())
     self.assertNotEqual(rng2.tolist(), rng3.tolist())
     self.assertNotEqual(rng1.tolist(), rng3.tolist())
 def test_new_rngs_new_values_each_call(self):
     """All four keys from two new_rngs(2) calls on one layer are distinct.

     All six pairs among the four keys are compared, so a key repeated
     across calls in either position (e.g. rng1 vs rng4) is also caught.
     """
     layer = base.Layer()
     rng1, rng2 = layer.new_rngs(2)
     rng3, rng4 = layer.new_rngs(2)
     # Keys returned by the same call.
     self.assertNotEqual(rng1.tolist(), rng2.tolist())
     self.assertNotEqual(rng3.tolist(), rng4.tolist())
     # Keys returned by different calls.
     self.assertNotEqual(rng1.tolist(), rng3.tolist())
     self.assertNotEqual(rng2.tolist(), rng4.tolist())
     self.assertNotEqual(rng1.tolist(), rng4.tolist())
     self.assertNotEqual(rng2.tolist(), rng3.tolist())
Example #8
0
 def test_new_rng_new_value_each_call(self):
     """After init, three successive new_rng() calls give pairwise-distinct keys.

     Every pair is compared, including the first draw against the third,
     which adjacent-only comparisons would miss.
     """
     input_signature = ShapeDtype((2, 3, 5))
     layer = base.Layer()
     _, _ = layer.init(input_signature)
     rng1 = layer.new_rng()
     rng2 = layer.new_rng()
     rng3 = layer.new_rng()
     self.assertNotEqual(rng1.tolist(), rng2.tolist())
     self.assertNotEqual(rng2.tolist(), rng3.tolist())
     self.assertNotEqual(rng1.tolist(), rng3.tolist())
Example #9
0
 def test_new_rngs_new_values_each_call(self):
   """After init, all four keys from two new_rngs(2) calls are distinct.

   All six pairs among the four keys are compared, so a key repeated
   across calls in either position (e.g. rng1 vs rng4) is also caught.
   """
   input_signature = shapes.ShapeDtype((2, 3, 5))
   layer = base.Layer()
   _, _ = layer.init(input_signature)
   rng1, rng2 = layer.new_rngs(2)
   rng3, rng4 = layer.new_rngs(2)
   # Keys returned by the same call.
   self.assertNotEqual(rng1.tolist(), rng2.tolist())
   self.assertNotEqual(rng3.tolist(), rng4.tolist())
   # Keys returned by different calls.
   self.assertNotEqual(rng1.tolist(), rng3.tolist())
   self.assertNotEqual(rng2.tolist(), rng4.tolist())
   self.assertNotEqual(rng1.tolist(), rng4.tolist())
   self.assertNotEqual(rng2.tolist(), rng3.tolist())
Example #10
0
    def test_custom_name(self):
        """str() of a layer shows its default name or an explicitly given one.

        Covers layers constructed directly from base.Layer and layers built
        with the @base.layer decorator, both with and without name=.
        """
        plain_text = str(base.Layer())
        self.assertIn('Layer', plain_text)
        self.assertNotIn('CustomLayer', plain_text)

        self.assertIn('CustomLayer', str(base.Layer(name='CustomLayer')))

        # Without name=, the decorated function's own name is expected in str().
        @base.layer()
        def DefaultDecoratorLayer(x, **unused_kwargs):
            return x

        default_decorated = DefaultDecoratorLayer()  # pylint: disable=no-value-for-parameter
        self.assertIn('DefaultDecoratorLayer', str(default_decorated))

        # With name=, that name is expected in str() instead.
        @base.layer(name='CustomDecoratorLayer')
        def NotDefaultDecoratorLayer(x, **unused_kwargs):
            return x

        custom_decorated = NotDefaultDecoratorLayer()  # pylint: disable=no-value-for-parameter
        self.assertIn('CustomDecoratorLayer', str(custom_decorated))
Example #11
0
 def test_init_returns_empty_weights_and_state(self):
   """init() on a bare base.Layer yields empty weights and empty state."""
   bare_layer = base.Layer()
   signature = shapes.ShapeDtype((2, 5))
   initialized = bare_layer.init(signature)
   weights, state = initialized
   for part in (weights, state):
     self.assertEmpty(part)
Example #12
0
 def test_new_weights_returns_empty(self):
   """new_weights() on a bare base.Layer returns an empty result."""
   self.assertEmpty(base.Layer().new_weights(shapes.ShapeDtype((2, 5))))
Example #13
0
 def test_call_raises_error(self):
   """Calling a bare base.Layer raises a base.LayerError.

   The error message must mention NotImplementedError.
   """
   bare_layer = base.Layer()
   batch = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]])
   with self.assertRaisesRegex(base.LayerError, 'NotImplementedError'):
     bare_layer(batch)
Example #14
0
 def test_forward_raises_error(self):
   """forward() on a bare base.Layer raises NotImplementedError directly."""
   bare_layer = base.Layer()
   batch = np.array([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]])
   with self.assertRaises(NotImplementedError):
     bare_layer.forward(batch)
 def test_new_rng_deterministic(self):
     """Without init, layers of differing arity give the same first new_rng()."""
     plain = base.Layer()
     two_way = base.Layer(n_in=2, n_out=2)
     first_key = plain.new_rng()
     other_key = two_way.new_rng()
     self.assertEqual(first_key.tolist(), other_key.tolist())