def test_symbolic_decorator3(self):
  """Checks a @tracer.symbolic graph that reuses one Add and ends in a Parallel.

  The traced graph computes d = a + b and e = d + c, then applies Relu to d
  and Tanh to e in parallel, returning both results.
  """
  add_lyr = cb.Add()
  # Parallel applies Relu to its first input and Tanh to its second.
  relu_tanh_lyr = cb.Parallel(activation_fns.Relu(), activation_fns.Tanh())

  @tracer.symbolic
  def make_layer(a, b, c):
    d = add_lyr @ (a, b)
    e = add_lyr @ (d, c)
    f, g = relu_tanh_lyr @ (d, e)
    return f, g

  layer = make_layer()  # pylint: disable=no-value-for-parameter

  # Seeded generator so the test is deterministic across runs.
  rng = onp.random.RandomState(0)
  a = rng.uniform(-10, 10, size=(2, 10)).astype(onp.float32)
  b = rng.uniform(-10, 10, size=(2, 10)).astype(onp.float32)
  c = rng.uniform(-10, 10, size=(2, 10)).astype(onp.float32)

  # The inputs are floating point, so the signature dtype must match them
  # (it was previously declared as int32).
  input_sd = ShapeDtype((2, 10), onp.float32)
  input_signature = (input_sd, input_sd, input_sd)
  weights, state = layer.new_weights_and_state(input_signature)
  res = layer((a, b, c), weights=weights, state=state, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable

  # First output: Relu(d) with d = a + b.
  result0 = onp.array(res[0])
  expected0 = onp.where(a + b > 0, a + b, 0.0)
  onp.testing.assert_allclose(result0, expected0, rtol=1e-5)

  # Second output: Tanh(e) with e = a + b + c.
  result1 = onp.array(res[1])
  expected1 = onp.tanh(a + b + c)
  onp.testing.assert_allclose(result1, expected1, rtol=1e-5)
def test_input_signatures_serial_batch_norm(self):
  """Serial(BatchNorm, Relu) hands the same input signature to both sublayers.

  BatchNorm is included because it actively uses state. Initialization and
  signature propagation are also expected to run without raising.
  """
  signature = ShapeDtype((3, 28, 28))
  batch_norm = normalization.BatchNorm()
  relu = activation_fns.Relu()
  model = cb.Serial(batch_norm, relu)
  model.init(signature)

  # Propagate the signature explicitly, then confirm that each sublayer
  # recorded exactly that signature for its input.
  model._set_input_signature_recursive(signature)  # pylint: disable=protected-access
  for sublayer in (batch_norm, relu):
    self.assertEqual(sublayer.input_signature, signature)
def test_relu(self):
  """Relu maps negative and zero inputs to 0.0 and passes positive ones through."""
  relu = activation_fns.Relu()
  inputs = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0])
  expected = [0.0, 0.0, 0.0, 2.0, 3.0, 5.0]
  outputs = relu(inputs)
  self.assertEqual(expected, list(outputs))