Example #1 (score: 0)
    def test_symbolic_decorator3(self):
        """Tracing a 3-input function yields a layer with two correct outputs.

        The traced graph computes d = a + b and e = d + c, then applies
        (Relu, Tanh) in parallel to (d, e), so the layer's outputs must be
        relu(a + b) and tanh(a + b + c).
        """
        add_lyr = cb.Add()
        tanh_lyr = cb.Parallel(activation_fns.Relu(), activation_fns.Tanh())

        @tracer.symbolic
        def make_layer(a, b, c):
            d = add_lyr @ (a, b)
            e = add_lyr @ (d, c)
            f, g = tanh_lyr @ (d, e)
            return f, g

        layer = make_layer()  # pylint: disable=no-value-for-parameter
        a = onp.random.uniform(-10, 10, size=(2, 10))
        b = onp.random.uniform(-10, 10, size=(2, 10))
        c = onp.random.uniform(-10, 10, size=(2, 10))
        # Fix: the signature dtype must agree with the float arrays actually
        # fed to the layer below; the original declared onp.int32 while
        # onp.random.uniform produces floats.
        input_sd = ShapeDtype((2, 10), onp.float32)
        input_signature = (input_sd, input_sd, input_sd)
        p, s = layer.new_weights_and_state(input_signature)
        res = layer((a, b, c), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
        result0 = onp.array(res[0])
        expected0 = onp.where(a + b > 0, a + b, 0.0)  # relu(a + b)
        onp.testing.assert_allclose(result0, expected0, rtol=1e-5)
        result1 = onp.array(res[1])
        expected1 = onp.tanh(a + b + c)  # tanh of the full sum
        onp.testing.assert_allclose(result1, expected1, rtol=1e-5)
Example #2 (score: 0)
  def test_input_signatures_serial_batch_norm(self):
    """Input signatures propagate through a Serial containing a stateful layer."""
    # BatchNorm is included because it actively uses state.
    signature = ShapeDtype((3, 28, 28))
    bn_layer = normalization.BatchNorm()
    relu_layer = activation_fns.Relu()
    model = cb.Serial(bn_layer, relu_layer)
    model.init(signature)

    # Propagating the signature should run without errors, and each sublayer
    # should end up seeing the same signature at its input.
    model._set_input_signature_recursive(signature)
    self.assertEqual(signature, bn_layer.input_signature)
    self.assertEqual(signature, relu_layer.input_signature)
Example #3 (score: 0)
 def test_relu(self):
     """Relu clamps negatives to zero and passes non-negatives through."""
     relu = activation_fns.Relu()
     inputs = np.array([-2.0, -1.0, 0.0, 2.0, 3.0, 5.0])
     expected = [0.0, 0.0, 0.0, 2.0, 3.0, 5.0]
     self.assertEqual(expected, list(relu(inputs)))