def test_symbolic_decorator3(self):
  add_lyr = cb.Add()
  tanh_lyr = cb.Parallel(core.Relu(), core.Tanh())

  # Trace a three-input layer that reuses add_lyr twice and feeds the
  # intermediate results through the Parallel(Relu, Tanh) sublayer.
  @tracer.symbolic
  def make_layer(a, b, c):
    d = add_lyr << (a, b)
    e = add_lyr << (d, c)
    f, g = tanh_lyr << (d, e)
    return f, g

  layer = make_layer()  # pylint: disable=no-value-for-parameter

  a = onp.random.uniform(-10, 10, size=(2, 10))
  b = onp.random.uniform(-10, 10, size=(2, 10))
  c = onp.random.uniform(-10, 10, size=(2, 10))

  input_sd = ShapeDtype((2, 10), onp.int32)
  input_signature = (input_sd, input_sd, input_sd)
  p, s = layer.new_weights_and_state(input_signature)
  res = layer((a, b, c), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable

  # The first output is Relu(a + b); the second is Tanh(a + b + c).
  result0 = onp.array(res[0])
  expected0 = onp.where(a + b > 0, a + b, 0.0)
  onp.testing.assert_allclose(result0, expected0, rtol=1e-5)

  result1 = onp.array(res[1])
  expected1 = onp.tanh(a + b + c)
  onp.testing.assert_allclose(result1, expected1, rtol=1e-5)
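
# A minimal NumPy-only sketch of the computation the traced layer above is
# expected to perform. `_reference_symbolic_outputs` is a hypothetical helper,
# not part of the layer API; it only documents the expected math.
def _reference_symbolic_outputs(self, a, b, c):
  d = a + b                     # first shared Add: d = a + b
  e = d + c                     # second shared Add: e = a + b + c
  f = onp.where(d > 0, d, 0.0)  # Relu branch of the Parallel sublayer
  g = onp.tanh(e)               # Tanh branch of the Parallel sublayer
  return f, g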
def test_input_signatures_serial_batch_norm(self):
  # Include a layer that actively uses state.
  batch_norm = normalization.BatchNorm()
  relu = core.Relu()
  batch_norm_and_relu = cb.Serial(batch_norm, relu)

  # Setting the signature on the Serial combinator should propagate the
  # correct shapes to the batch_norm and relu sublayers, and should run
  # without errors.
  batch_norm_and_relu.input_signature = ShapeDtype((3, 28, 28))
  self.assertEqual(batch_norm.input_signature, ShapeDtype((3, 28, 28)))
  self.assertEqual(relu.input_signature, ShapeDtype((3, 28, 28)))
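
# A minimal end-to-end sketch, assuming the new_weights_and_state and call
# conventions used in the tracer test above also apply to this stack.
# `_run_batch_norm_relu_sketch` is a hypothetical helper, not part of the
# original suite; it only illustrates that BatchNorm -> Relu preserves the
# (3, 28, 28) input shape.
def _run_batch_norm_relu_sketch(self):
  layer = cb.Serial(normalization.BatchNorm(), core.Relu())
  x = onp.random.uniform(-1, 1, size=(3, 28, 28)).astype(onp.float32)
  weights, state = layer.new_weights_and_state(ShapeDtype((3, 28, 28)))
  y = layer(x, weights=weights, state=state, rng=jax.random.PRNGKey(0))
  self.assertEqual(onp.asarray(y).shape, (3, 28, 28))  # shape is preserved
  return y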