def test_new_rngs_deterministic(self):
  """Checks new_rngs(2) yields matching keys from two freshly built layers.

  A default-constructed layer and one built with n_in=2, n_out=2 each draw two
  keys; the i-th key from one must equal the i-th key from the other.

  NOTE(review): this chunk also defines another `test_new_rngs_deterministic`;
  if both are methods of one class, the later definition shadows this one.
  """
  first_layer = base.Layer()
  second_layer = base.Layer(n_in=2, n_out=2)
  key_a1, key_a2 = first_layer.new_rngs(2)
  key_b1, key_b2 = second_layer.new_rngs(2)
  for left, right in ((key_a1, key_b1), (key_a2, key_b2)):
    self.assertEqual(left.tolist(), right.tolist())
def test_custom_name(self):
  """Checks str(layer) shows the default name, or the one passed via `name`.

  NOTE(review): this chunk also defines another `test_custom_name`; if both
  are methods of one class, the later definition shadows this one.
  """
  default_text = str(base.Layer())
  self.assertIn('Layer', default_text)
  self.assertNotIn('CustomLayer', default_text)
  named_text = str(base.Layer(name='CustomLayer'))
  self.assertIn('CustomLayer', named_text)
def test_new_rng_deterministic(self):
  """Checks new_rng draws the same key from two freshly initialized layers.

  One layer is default-constructed; the other declares n_in=2, n_out=2. After
  init, the first new_rng() key of each must be equal.

  NOTE(review): this chunk also defines another `test_new_rng_deterministic`;
  if both are methods of one class, the later definition shadows this one.
  """
  single_input_signature = ShapeDtype((2, 3, 5))
  # layer2 declares n_in=2, so it is initialized with one signature per input.
  pair_input_signature = (ShapeDtype((2, 3, 5)), ShapeDtype((2, 3, 5)))
  layer1 = base.Layer()
  layer2 = base.Layer(n_in=2, n_out=2)
  _, _ = layer1.init(single_input_signature)
  _, _ = layer2.init(pair_input_signature)
  rng1 = layer1.new_rng()
  rng2 = layer2.new_rng()
  self.assertEqual(rng1.tolist(), rng2.tolist())
def test_new_rngs_deterministic(self):
  """Checks new_rngs(2) matches across a 1-input and a 2-input layer after init.

  Each layer is initialized with a signature matching its input count, then
  draws two keys; corresponding keys must be equal.

  NOTE(review): this chunk also defines another `test_new_rngs_deterministic`;
  if both are methods of one class, the later definition shadows the earlier.
  """
  single_sig = ShapeDtype((2, 3, 5))
  pair_sig = (ShapeDtype((2, 3, 5)), ShapeDtype((2, 3, 5)))
  one_input_layer = base.Layer()
  two_input_layer = base.Layer(n_in=2, n_out=2)
  _, _ = one_input_layer.init(single_sig)
  _, _ = two_input_layer.init(pair_sig)
  a_first, a_second = one_input_layer.new_rngs(2)
  b_first, b_second = two_input_layer.new_rngs(2)
  for lhs, rhs in zip((a_first, a_second), (b_first, b_second)):
    self.assertEqual(lhs.tolist(), rhs.tolist())
def test_forward_with_state_raises_error(self):
  """Checks forward_with_state on a bare Layer raises NotImplementedError.

  The call passes the module's empty weights/state sentinels and None as the
  final positional argument.
  """
  plain_layer = base.Layer()
  first_row = [1, 2, 3, 4, 5]
  batch = np.array([first_row, [10 * v for v in first_row]])
  with self.assertRaises(NotImplementedError):
    _, _ = plain_layer.forward_with_state(
        batch, base.EMPTY_WEIGHTS, base.EMPTY_STATE, None)
def test_new_rng_new_value_each_call(self):
  """Checks successive new_rng calls (without init) all return distinct keys.

  Every pair among three draws is compared, not just adjacent ones, so a
  generator that alternated between two keys would also be caught.

  NOTE(review): this chunk also defines another
  `test_new_rng_new_value_each_call`; if both are methods of one class, the
  later definition shadows this one.
  """
  layer = base.Layer()
  rng1 = layer.new_rng()
  rng2 = layer.new_rng()
  rng3 = layer.new_rng()
  self.assertNotEqual(rng1.tolist(), rng2.tolist())
  self.assertNotEqual(rng2.tolist(), rng3.tolist())
  # Non-adjacent pair: catches keys that repeat with period two.
  self.assertNotEqual(rng1.tolist(), rng3.tolist())
def test_new_rngs_new_values_each_call(self):
  """Checks new_rngs(2) (without init) returns fresh keys within and across calls.

  NOTE(review): this chunk also defines another
  `test_new_rngs_new_values_each_call`; if both are methods of one class, the
  later definition shadows this one.
  """
  layer = base.Layer()
  first_a, first_b = layer.new_rngs(2)
  second_a, second_b = layer.new_rngs(2)
  must_differ = (
      (first_a, first_b),    # within the first call
      (second_a, second_b),  # within the second call
      (first_a, second_a),   # across calls, first slot
      (first_b, second_b),   # across calls, second slot
  )
  for one, other in must_differ:
    self.assertNotEqual(one.tolist(), other.tolist())
def test_new_rng_new_value_each_call(self):
  """Checks successive new_rng calls after init all return distinct keys.

  Every pair among three draws is compared, not just adjacent ones, so a
  generator that alternated between two keys would also be caught.

  NOTE(review): this chunk also defines another
  `test_new_rng_new_value_each_call`; if both are methods of one class, the
  later definition shadows the earlier.
  """
  input_signature = ShapeDtype((2, 3, 5))
  layer = base.Layer()
  _, _ = layer.init(input_signature)
  rng1 = layer.new_rng()
  rng2 = layer.new_rng()
  rng3 = layer.new_rng()
  self.assertNotEqual(rng1.tolist(), rng2.tolist())
  self.assertNotEqual(rng2.tolist(), rng3.tolist())
  # Non-adjacent pair: catches keys that repeat with period two.
  self.assertNotEqual(rng1.tolist(), rng3.tolist())
def test_new_rngs_new_values_each_call(self):
  """Checks new_rngs(2) after init returns fresh keys within and across calls.

  NOTE(review): this chunk also defines another
  `test_new_rngs_new_values_each_call`; if both are methods of one class, the
  later definition shadows the earlier.
  """
  signature = shapes.ShapeDtype((2, 3, 5))
  layer = base.Layer()
  _, _ = layer.init(signature)
  call1_a, call1_b = layer.new_rngs(2)
  call2_a, call2_b = layer.new_rngs(2)
  pairs = [
      (call1_a, call1_b),
      (call2_a, call2_b),
      (call1_a, call2_a),
      (call1_b, call2_b),
  ]
  for left, right in pairs:
    self.assertNotEqual(left.tolist(), right.tolist())
def test_custom_name(self):
  """Checks layer names in str() for plain and decorator-built layers.

  Covers a default Layer, a Layer given `name=`, a function wrapped by
  `base.layer()` (expected to be named after the function), and one wrapped by
  `base.layer(name=...)` (expected to use the explicit name).

  NOTE(review): this chunk also defines another `test_custom_name`; if both
  are methods of one class, the later definition shadows the earlier.
  """
  plain_text = str(base.Layer())
  self.assertIn('Layer', plain_text)
  self.assertNotIn('CustomLayer', plain_text)
  self.assertIn('CustomLayer', str(base.Layer(name='CustomLayer')))

  # The function name below is what the test expects str() to report, so it
  # must not be renamed.
  @base.layer()
  def DefaultDecoratorLayer(x, **unused_kwargs):
    return x

  default_decorated = DefaultDecoratorLayer()  # pylint: disable=no-value-for-parameter
  self.assertIn('DefaultDecoratorLayer', str(default_decorated))

  @base.layer(name='CustomDecoratorLayer')
  def NotDefaultDecoratorLayer(x, **unused_kwargs):
    return x

  explicit_decorated = NotDefaultDecoratorLayer()  # pylint: disable=no-value-for-parameter
  self.assertIn('CustomDecoratorLayer', str(explicit_decorated))
def test_init_returns_empty_weights_and_state(self):
  """Checks init on a bare Layer returns empty weights and empty state."""
  layer = base.Layer()
  signature = shapes.ShapeDtype((2, 5))
  init_weights, init_state = layer.init(signature)
  for produced in (init_weights, init_state):
    self.assertEmpty(produced)
def test_new_weights_returns_empty(self):
  """Checks new_weights on a bare Layer returns an empty value."""
  layer = base.Layer()
  self.assertEmpty(layer.new_weights(shapes.ShapeDtype((2, 5))))
def test_call_raises_error(self):
  """Checks calling a bare Layer raises base.LayerError.

  The error message is expected to mention 'NotImplementedError'.
  """
  layer = base.Layer()
  base_row = [1, 2, 3, 4, 5]
  inputs = np.array([base_row, [10 * v for v in base_row]])
  with self.assertRaisesRegex(base.LayerError, 'NotImplementedError'):
    _ = layer(inputs)
def test_forward_raises_error(self):
  """Checks Layer.forward on a bare Layer raises NotImplementedError."""
  bare_layer = base.Layer()
  small = [1, 2, 3, 4, 5]
  large = [10, 20, 30, 40, 50]
  batch = np.array([small, large])
  with self.assertRaises(NotImplementedError):
    _ = bare_layer.forward(batch)
def test_new_rng_deterministic(self):
  """Checks the first new_rng key (without init) matches across two fresh layers.

  NOTE(review): this chunk also defines another `test_new_rng_deterministic`;
  if both are methods of one class, the later definition shadows the earlier.
  """
  fresh_layers = (base.Layer(), base.Layer(n_in=2, n_out=2))
  first_keys = [candidate.new_rng() for candidate in fresh_layers]
  self.assertEqual(first_keys[0].tolist(), first_keys[1].tolist())