예제 #1
0
    def test_symbolic_decorator3(self):
        """Trace a three-input symbolic layer and compare both outputs to NumPy.

        The traced function reuses one ``Add`` layer twice and then feeds the
        two sums through a ``Parallel(Relu, Tanh)`` combinator. The first
        output is expected to equal relu(a + b) and the second tanh(a + b + c).
        """
        adder = cb.Add()
        relu_and_tanh = cb.Parallel(core.Relu(), core.Tanh())

        @tracer.symbolic
        def make_layer(a, b, c):
            sum_ab = adder << (a, b)
            sum_abc = adder << (sum_ab, c)
            relu_out, tanh_out = relu_and_tanh << (sum_ab, sum_abc)
            return relu_out, tanh_out

        layer = make_layer()  # pylint: disable=no-value-for-parameter

        # Three independent random inputs, drawn in the order a, b, c.
        a, b, c = (onp.random.uniform(-10, 10, size=(2, 10)) for _ in range(3))

        # NOTE(review): the signature declares int32 although the inputs are
        # floating point; presumably harmless here -- confirm it is intended.
        input_sd = ShapeDtype((2, 10), onp.int32)
        weights, state = layer.new_weights_and_state((input_sd,) * 3)

        outputs = layer(  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
            (a, b, c), weights=weights, state=state, rng=jax.random.PRNGKey(0))

        # First output: relu applied to a + b.
        sum_ab = a + b
        onp.testing.assert_allclose(
            onp.array(outputs[0]), onp.where(sum_ab > 0, sum_ab, 0.0), rtol=1e-5)

        # Second output: tanh applied to a + b + c.
        onp.testing.assert_allclose(
            onp.array(outputs[1]), onp.tanh(sum_ab + c), rtol=1e-5)
    def test_input_signatures_serial_batch_norm(self):
        """Assigning a Serial's input signature should reach each sublayer."""
        shape = (3, 28, 28)

        # BatchNorm is included as a layer that actively uses state.
        norm_layer = normalization.BatchNorm()
        relu_layer = core.Relu()
        combined = cb.Serial(norm_layer, relu_layer)

        # The assignment must run without errors, and afterwards the signature
        # entering batch norm and the one leaving it (i.e. entering relu)
        # should both match the assigned shape.
        combined.input_signature = ShapeDtype(shape)
        for sublayer in (norm_layer, relu_layer):
            self.assertEqual(sublayer.input_signature, ShapeDtype(shape))