def test_tracer_apply(self):
  """Applying a layer to two Tracers records a matching ApplyExpr."""
  layer = cb.Add()
  left = tracer.Tracer('a')
  right = tracer.Tracer('b')
  traced = layer @ (left, right)
  expected = tracer.ApplyExpr(layer, ('a', 'b'))
  self.assertEqual(traced.expr, expected)
def test_symbolic_decorator3(self):
  """Symbolic function whose trace includes a multi-output Parallel layer.

  Builds d = a + b, e = d + c, then (relu(d), tanh(e)) via a Parallel layer,
  and checks both outputs numerically.
  """
  add_lyr = cb.Add()
  tanh_lyr = cb.Parallel(activation_fns.Relu(), activation_fns.Tanh())

  @tracer.symbolic
  def make_layer(a, b, c):
    d = add_lyr @ (a, b)
    e = add_lyr @ (d, c)
    f, g = tanh_lyr @ (d, e)
    return f, g

  layer = make_layer()  # pylint: disable=no-value-for-parameter
  a = onp.random.uniform(-10, 10, size=(2, 10))
  b = onp.random.uniform(-10, 10, size=(2, 10))
  c = onp.random.uniform(-10, 10, size=(2, 10))
  # Fix: the inputs above are floats (onp.random.uniform), so the declared
  # signature dtype must be float32 — it was onp.int32, inconsistent with the
  # actual data fed to the layer.
  input_sd = ShapeDtype((2, 10), onp.float32)
  input_signature = (input_sd, input_sd, input_sd)
  p, s = layer.new_weights_and_state(input_signature)
  res = layer((a, b, c), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
  # First output: relu(a + b).
  result0 = onp.array(res[0])
  expected0 = onp.where(a + b > 0, a + b, 0.0)
  onp.testing.assert_allclose(result0, expected0, rtol=1e-5)
  # Second output: tanh(a + b + c).
  result1 = onp.array(res[1])
  expected1 = onp.tanh(a + b + c)
  onp.testing.assert_allclose(result1, expected1, rtol=1e-5)
def test_symbolic_decorator4(self):
  """A Python-level loop inside a symbolic function, controlled by a kwarg.

  make_layer(n=k) traces k chained adds, computing a + k*b; verified for
  k = 3 and k = 5.
  """
  add_lyr = cb.Add()

  @tracer.symbolic
  def make_layer(a, b, n=1):
    for _ in range(n):
      a = add_lyr @ (a, b)
    return a

  x = onp.random.randint(0, 10, size=(2, 10))
  y = onp.random.randint(0, 10, size=(2, 10))
  for repeats in (3, 5):
    layer = make_layer(n=repeats)  # pylint: disable=no-value-for-parameter
    sd = ShapeDtype((2, 10), onp.int32)
    p, s = layer.new_weights_and_state((sd, sd))
    res = layer((x, y), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
    onp.testing.assert_allclose(onp.array(res), x + repeats * y)
def test_symbolic_decorator3_legacy_api(self):
  """Legacy-API variant of the Parallel-layer symbolic test.

  NOTE(review): this method previously reused the name
  `test_symbolic_decorator3`, so inside the class one of the two definitions
  silently shadowed the other and never ran. Renamed so both are collected.
  It exercises the older `<<` apply operator and the
  `new_params_and_state` / `params=` API.
  """
  add_lyr = cb.Add()
  tanh_lyr = cb.Parallel(core.Relu(), core.Tanh())

  @tracer.symbolic
  def make_layer(a, b, c):
    d = add_lyr << (a, b)
    e = add_lyr << (d, c)
    f, g = tanh_lyr << (d, e)
    return f, g

  layer = make_layer()  # pylint: disable=no-value-for-parameter
  a = onp.random.uniform(-10, 10, size=(2, 10))
  b = onp.random.uniform(-10, 10, size=(2, 10))
  c = onp.random.uniform(-10, 10, size=(2, 10))
  p, s = layer.new_params_and_state(
      ((2, 10), (2, 10), (2, 10)),
      (onp.float32, onp.float32, onp.float32),
      rng=jax.random.PRNGKey(0))
  res = layer((a, b, c), params=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
  # First output: relu(a + b).
  result0 = onp.array(res[0])
  expected0 = onp.where(a + b > 0, a + b, 0.0)
  onp.testing.assert_allclose(result0, expected0, rtol=1e-5)
  # Second output: tanh(a + b + c).
  result1 = onp.array(res[1])
  expected1 = onp.tanh(a + b + c)
  onp.testing.assert_allclose(result1, expected1, rtol=1e-5)
def test_apply_to_eqn(self):
  """traces_to_eqns flattens a single traced apply into one ApplyEqn."""
  layer = cb.Add()
  traced = layer @ (tracer.Tracer('a'), tracer.Tracer('b'))
  eqns, outputs = tracer.traces_to_eqns(traced)
  expected_eqns = [tracer.ApplyEqn(layer, ('a', 'b'), ('var0',))]
  self.assertEqual(eqns, expected_eqns)
  self.assertEqual(outputs, ('var0',))
def test_recombine(self):
  """recombine turns a linear chain of eqns into a Serial of those layers."""
  add_lyr = cb.Add()
  tanh_lyr = activation_fns.Tanh()
  eqns = [
      tracer.ApplyEqn(add_lyr, ('a', 'b'), ('var1',)),
      tracer.ApplyEqn(tanh_lyr, ('var1',), ('var2',)),
  ]
  model = tracer.recombine(eqns, ('a', 'b'), ('var2',))
  # The result must be exactly a Serial wrapping the two layers in order.
  self.assertEqual(type(model), cb.Serial)
  self.assertEqual(model.sublayers[0], add_lyr)
  self.assertEqual(model.sublayers[1], tanh_lyr)
def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads,
                        dropout, dropout_shared_axes, mode, ff_activation,
                        context_bias_layer, location_bias_layer, total_pooling,
                        resampling_fn):
  """Attention resampling.

  Builds a list of layers that resamples the sequence (via `resampling_fn`,
  e.g. shortening or upsampling by `shorten_factor`), runs relative attention
  between the resampled and original representations, and applies a
  feed-forward block — each wrapped in a residual connection.

  Args:
    shorten_factor: Resampling factor passed to `resampling_fn`.
    d_model: Model/feature depth passed to attention, feed-forward and
      resampling constructors.
    is_upsampling: If True, an extra Select+Add step merges the resampled
      stream back with the original stream before attention.
    d_ff: Feed-forward hidden depth.
    n_heads: Number of attention heads.
    dropout: Dropout rate used by attention, feed-forward and the local
      `_Dropout` layers.
    dropout_shared_axes: Axes along which dropout masks are shared.
    mode: 'train' / 'eval' / 'predict' mode flag, forwarded to sublayers.
    ff_activation: Activation constructor for the feed-forward block.
    context_bias_layer: Shared context-bias layer for relative attention.
    location_bias_layer: Shared location-bias layer for relative attention.
    total_pooling: Total pooling factor, forwarded to relative attention.
    resampling_fn: Constructor `(shorten_factor, d_model, mode=...) -> layer`
      that produces the resampling layer.

  Returns:
    A list of layers (a Serial-style block) implementing the resampling +
    attention + feed-forward stage.
  """
  attention = RelativeAttentionLMLayer(
      d_model, context_bias_layer, location_bias_layer,
      total_pooling, n_heads=n_heads, dropout=dropout, mode=mode)

  feed_forward = FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

  resampling = resampling_fn(shorten_factor, d_model, mode=mode)

  def _Dropout():
    # Fresh Dropout instance per use site, all sharing the same rate/axes.
    return core.Dropout(rate=dropout, shared_axes=dropout_shared_axes,
                        mode=mode)

  return [
      LayerNorm(),            # h
      cb.Branch(cb.Serial(
          resampling,
          LayerNorm(),
      ), None),               # h', h
      # Only in the upsampling direction: add the (reordered) original stream
      # back onto the resampled one.
      cb.Serial(              # pylint: disable=g-long-ternary
          cb.Select([0, 2, 1, 2]),
          cb.Add(),
      ) if is_upsampling else [],
      cb.Residual(
          cb.Select([0, 1, 1]),  # h', h, h
          attention,
          _Dropout(),
      ),
      cb.Residual(
          LayerNorm(),
          feed_forward,
          _Dropout(),
      ),
  ]
def test_symbolic_decorator1(self):
  """Chained adds through @tracer.symbolic compute a three-way sum."""
  add_lyr = cb.Add()

  @tracer.symbolic
  def make_layer(a, b, c):
    d = add_lyr @ (a, b)
    e = add_lyr @ (d, c)
    return e

  layer = make_layer()  # pylint: disable=no-value-for-parameter
  x = onp.random.randint(0, 10, size=(2, 10))
  y = onp.random.randint(0, 10, size=(2, 10))
  z = onp.random.randint(0, 10, size=(2, 10))
  sd = ShapeDtype((2, 10), onp.int32)
  p, s = layer.new_weights_and_state((sd, sd, sd))
  res = layer((x, y, z), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
  onp.testing.assert_allclose(onp.array(res), x + y + z)
def test_symbolic_decorator2(self):
  """A unary layer feeding into chained adds (legacy params API)."""
  add_lyr = cb.Add()
  tanh_lyr = core.Tanh()

  @tracer.symbolic
  def make_layer(a, b, c):
    a = tanh_lyr << a
    d = add_lyr << (a, b)
    e = add_lyr << (d, c)
    return e

  layer = make_layer()  # pylint: disable=no-value-for-parameter
  a = onp.random.randint(0, 10, size=(2, 10))
  b = onp.random.randint(0, 10, size=(2, 10))
  c = onp.random.randint(0, 10, size=(2, 10))
  shapes = ((2, 10),) * 3
  dtypes = (onp.int32,) * 3
  p, s = layer.new_params_and_state(shapes, dtypes,
                                    rng=jax.random.PRNGKey(0))
  res = layer((a, b, c), params=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
  onp.testing.assert_allclose(onp.array(res), onp.tanh(a) + b + c)
def test_branch_add_div(self):
  """Branch feeds both inputs to Add and to divide_by, yielding two outputs."""
  layer = cb.Branch(cb.Add(), divide_by(0.5))
  signature = (ShapeDtype((3, 2)), ShapeDtype((3, 2)))
  out_shape = base.check_shape_agreement(layer, signature)
  self.assertEqual(out_shape, ((3, 2), (3, 2)))
def test_branch_name(self):
  """The string form of a Branch combinator mentions its type name."""
  branch = cb.Branch(cb.Add(), divide_by(0.5))
  self.assertIn('Branch', str(branch))
def test_branch_name2(self):
  """Duplicate Branch-name check, renamed so both definitions run.

  NOTE(review): this method previously reused the name `test_branch_name`,
  which inside a class silently shadows the earlier definition so only one
  of the two tests was ever collected. Renamed to fix the collision.
  """
  layer = cb.Branch(cb.Add(), divide_by(0.5))  # pylint: disable=no-value-for-parameter
  self.assertIn('Branch', str(layer))