Example 1
    def test_tracer_apply(self):
        lyr = cb.Add()
        a = tracer.Tracer('a')
        b = tracer.Tracer('b')
        c = lyr @ (a, b)
        result = tracer.ApplyExpr(lyr, ('a', 'b'))
        self.assertEqual(c.expr, result)
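The test above exercises trax's tracer module, which overloads a layer's `@` operator so that applying the layer to `Tracer` objects records a symbolic `ApplyExpr` instead of computing values. The following toy sketch is my own illustration of that idea, not the library's implementation; the class and field names simply mirror the test.

import collections

ApplyExpr = collections.namedtuple('ApplyExpr', ['layer', 'inputs'])

class Tracer:
    """Carries a symbolic expression (a name or an ApplyExpr) instead of data."""

    def __init__(self, expr):
        self.expr = expr

class TraceableLayer:
    """Gives a layer the `layer @ (tracer, ...)` syntax used in the tests."""

    def __matmul__(self, args):
        if not isinstance(args, tuple):
            args = (args,)
        return Tracer(ApplyExpr(self, tuple(t.expr for t in args)))

class Add(TraceableLayer):
    pass

lyr = Add()
c = lyr @ (Tracer('a'), Tracer('b'))
assert c.expr == ApplyExpr(lyr, ('a', 'b'))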
Example 2
    def test_symbolic_decorator3(self):
        add_lyr = cb.Add()
        tanh_lyr = cb.Parallel(activation_fns.Relu(), activation_fns.Tanh())

        @tracer.symbolic
        def make_layer(a, b, c):
            d = add_lyr @ (a, b)
            e = add_lyr @ (d, c)
            f, g = tanh_lyr @ (d, e)
            return f, g

        layer = make_layer()  # pylint: disable=no-value-for-parameter
        a = onp.random.uniform(-10, 10, size=(2, 10))
        b = onp.random.uniform(-10, 10, size=(2, 10))
        c = onp.random.uniform(-10, 10, size=(2, 10))
        input_sd = ShapeDtype((2, 10), onp.int32)
        input_signature = (input_sd, input_sd, input_sd)
        p, s = layer.new_weights_and_state(input_signature)
        res = layer((a, b, c), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
        result0 = onp.array(res[0])
        expected0 = onp.where(a + b > 0, a + b, 0.0)
        onp.testing.assert_allclose(result0, expected0, rtol=1e-5)
        result1 = onp.array(res[1])
        expected1 = onp.tanh(a + b + c)
        onp.testing.assert_allclose(result1, expected1, rtol=1e-5)
Example 3
    def test_symbolic_decorator4(self):
        add_lyr = cb.Add()

        @tracer.symbolic
        def make_layer(a, b, n=1):
            for _ in range(n):
                a = add_lyr @ (a, b)
            return a

        layer = make_layer(n=3)  # pylint: disable=no-value-for-parameter
        a = onp.random.randint(0, 10, size=(2, 10))
        b = onp.random.randint(0, 10, size=(2, 10))
        input_sd = ShapeDtype((2, 10), onp.int32)
        input_signature = (input_sd, input_sd)
        p, s = layer.new_weights_and_state(input_signature)
        res = layer((a, b), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
        result = onp.array(res)
        expected = a + 3 * b
        onp.testing.assert_allclose(result, expected)

        layer = make_layer(n=5)  # pylint: disable=no-value-for-parameter
        input_sd = ShapeDtype((2, 10), onp.int32)
        input_signature = (input_sd, input_sd)
        p, s = layer.new_weights_and_state(input_signature)
        res = layer((a, b), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
        result = onp.array(res)
        expected = a + 5 * b
        onp.testing.assert_allclose(result, expected)
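For reference, the layer built by make_layer(n=3) computes a + 3*b by feeding each Add result back in as the first operand. A hand-written combinator with the same behavior, given here as a sketch of one possibility rather than the exact network the tracer emits, duplicates b on the data stack and applies Add three times:

equivalent = cb.Serial(
    cb.Select([0, 1, 1, 1]),  # stack: a, b, b, b
    cb.Add(),                 # stack: a + b, b, b
    cb.Add(),                 # stack: a + 2*b, b
    cb.Add(),                 # stack: a + 3*b
)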
Example 4
    def test_symbolic_decorator3(self):
        add_lyr = cb.Add()
        tanh_lyr = cb.Parallel(core.Relu(), core.Tanh())

        @tracer.symbolic
        def make_layer(a, b, c):
            d = add_lyr << (a, b)
            e = add_lyr << (d, c)
            f, g = tanh_lyr << (d, e)
            return f, g

        layer = make_layer()  # pylint: disable=no-value-for-parameter
        a = onp.random.uniform(-10, 10, size=(2, 10))
        b = onp.random.uniform(-10, 10, size=(2, 10))
        c = onp.random.uniform(-10, 10, size=(2, 10))
        p, s = layer.new_params_and_state(
            ((2, 10), (2, 10), (2, 10)),
            (onp.float32, onp.float32, onp.float32),
            rng=jax.random.PRNGKey(0))
        res = layer((a, b, c), params=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
        result0 = onp.array(res[0])
        expected0 = onp.where(a + b > 0, a + b, 0.0)
        onp.testing.assert_allclose(result0, expected0, rtol=1e-5)
        result1 = onp.array(res[1])
        expected1 = onp.tanh(a + b + c)
        onp.testing.assert_allclose(result1, expected1, rtol=1e-5)
Example 5
    def test_apply_to_eqn(self):
        lyr = cb.Add()
        a = tracer.Tracer('a')
        b = tracer.Tracer('b')
        c = lyr @ (a, b)
        eqns, outputs = tracer.traces_to_eqns(c)
        result0 = [tracer.ApplyEqn(lyr, ('a', 'b'), ('var0',))]
        result1 = ('var0',)
        self.assertEqual(eqns, result0)
        self.assertEqual(outputs, result1)
Example 6
    def test_recombine(self):
        add_lyr = cb.Add()
        tanh_lyr = activation_fns.Tanh()
        eqns = [
            tracer.ApplyEqn(add_lyr, ('a', 'b'), ('var1',)),
            tracer.ApplyEqn(tanh_lyr, ('var1',), ('var2',)),
        ]
        outputs = ('var2',)
        model = tracer.recombine(eqns, ('a', 'b'), outputs)
        self.assertEqual(type(model), cb.Serial)
        self.assertEqual(model.sublayers[0], add_lyr)
        self.assertEqual(model.sublayers[1], tanh_lyr)
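For the strictly linear chain in this test, where each equation consumes exactly the outputs of the previous one, recombination amounts to stacking the layers in order. The sketch below covers that special case only; it assumes ApplyEqn is an indexable record (as the positional constructor calls above suggest), and trax's real recombine additionally handles fan-out and input reordering.

def recombine_linear_chain(eqns):
    # The first field of each ApplyEqn is the layer (per the constructor calls
    # above); a linear chain of equations becomes a Serial of those layers.
    return cb.Serial(*[eqn[0] for eqn in eqns])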
Example 7
def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads,
                        dropout, dropout_shared_axes, mode, ff_activation,
                        context_bias_layer, location_bias_layer, total_pooling,
                        resampling_fn):
    """Attention resampling."""

    attention = RelativeAttentionLMLayer(d_model,
                                         context_bias_layer,
                                         location_bias_layer,
                                         total_pooling,
                                         n_heads=n_heads,
                                         dropout=dropout,
                                         mode=mode)

    feed_forward = FeedForwardBlock(d_model, d_ff, dropout,
                                    dropout_shared_axes, mode, ff_activation)

    resampling = resampling_fn(shorten_factor, d_model, mode=mode)

    def _Dropout():
        return core.Dropout(rate=dropout,
                            shared_axes=dropout_shared_axes,
                            mode=mode)

    return [
        LayerNorm(),  # h
        cb.Branch(cb.Serial(
            resampling,
            LayerNorm(),
        ), None),  # h', h
        cb.Serial(  # pylint: disable=g-long-ternary
            cb.Select([0, 2, 1, 2]),
            cb.Add(),
        ) if is_upsampling else [],
        cb.Residual(
            cb.Select([0, 1, 1]),  # h', h, h
            attention,
            _Dropout(),
        ),
        cb.Residual(
            LayerNorm(),
            feed_forward,
            _Dropout(),
        ),
    ]
Example 8
    def test_symbolic_decorator1(self):
        add_lyr = cb.Add()

        @tracer.symbolic
        def make_layer(a, b, c):
            d = add_lyr @ (a, b)
            e = add_lyr @ (d, c)
            return e

        layer = make_layer()  # pylint: disable=no-value-for-parameter
        a = onp.random.randint(0, 10, size=(2, 10))
        b = onp.random.randint(0, 10, size=(2, 10))
        c = onp.random.randint(0, 10, size=(2, 10))
        input_sd = ShapeDtype((2, 10), onp.int32)
        input_signature = (input_sd, input_sd, input_sd)
        p, s = layer.new_weights_and_state(input_signature)
        res = layer((a, b, c), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
        result = onp.array(res)
        expected = a + b + c
        onp.testing.assert_allclose(result, expected)
Example 9
    def test_symbolic_decorator2(self):
        add_lyr = cb.Add()
        tanh_lyr = core.Tanh()

        @tracer.symbolic
        def make_layer(a, b, c):
            a = tanh_lyr << a
            d = add_lyr << (a, b)
            e = add_lyr << (d, c)
            return e

        layer = make_layer()  # pylint: disable=no-value-for-parameter
        a = onp.random.randint(0, 10, size=(2, 10))
        b = onp.random.randint(0, 10, size=(2, 10))
        c = onp.random.randint(0, 10, size=(2, 10))
        p, s = layer.new_params_and_state(((2, 10), (2, 10), (2, 10)),
                                          (onp.int32, onp.int32, onp.int32),
                                          rng=jax.random.PRNGKey(0))
        res = layer((a, b, c), params=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
        result = onp.array(res)
        expected = onp.tanh(a) + b + c
        onp.testing.assert_allclose(result, expected)
Example 10
    def test_branch_add_div(self):
        layer = cb.Branch(cb.Add(), divide_by(0.5))
        input_signature = (ShapeDtype((3, 2)), ShapeDtype((3, 2)))
        expected_shape = ((3, 2), (3, 2))
        output_shape = base.check_shape_agreement(layer, input_signature)
        self.assertEqual(output_shape, expected_shape)
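The divide_by helper used by these Branch tests is defined elsewhere in the test module and is not shown here. A minimal stand-in with the same interface, assuming trax's Fn constructor for pure-function layers (the alias tl and the layer name 'DivideBy' are illustrative), could look like this:

from trax import layers as tl

def divide_by(val):
    """Returns a one-input, one-output layer that divides its input by val."""
    return tl.Fn('DivideBy', lambda x: x / val)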
Example 11
    def test_branch_name(self):
        layer = cb.Branch(cb.Add(), divide_by(0.5))
        self.assertIn('Branch', str(layer))
Example 12
    def test_branch_name(self):
        layer = cb.Branch(cb.Add(), divide_by(0.5))  # pylint: disable=no-value-for-parameter
        self.assertIn('Branch', str(layer))