Exemplo n.º 1
0
 def test_apply_index_to_eqn(self):
     """Tracing a two-output layer yields one ApplyEqn followed by IndexEqns."""
     parallel_tanh = cb.Parallel(core.Tanh(), core.Tanh())
     in_a, in_b = tracer.Tracer('a'), tracer.Tracer('b')
     out_c, out_d = parallel_tanh << (in_a, in_b)
     eqns, outputs = tracer.traces_to_eqns((out_c, out_d))
     # The layer application is recorded against a single intermediate var
     # ('var2'); one IndexEqn per output then extracts 'var0' and 'var1'.
     expected_eqns = [
         tracer.ApplyEqn(parallel_tanh, ('a', 'b'), ('var2',)),
         tracer.IndexEqn(0, 'var2', 'var0'),
         tracer.IndexEqn(1, 'var2', 'var1'),
     ]
     self.assertEqual(eqns, expected_eqns)
     self.assertEqual(outputs, ('var0', 'var1'))
Exemplo n.º 2
0
    def test_symbolic_decorator3(self):
        """Symbolic layer whose traced graph fans out into Parallel(Relu, Tanh).

        The traced function computes d = a + b and e = d + c, then routes
        (d, e) through Parallel(Relu, Tanh), so the layer's outputs are
        expected to be (relu(a + b), tanh(a + b + c)).
        """
        add_lyr = cb.Add()
        # Relu is applied to the first input of the pair, Tanh to the second.
        relu_tanh_lyr = cb.Parallel(core.Relu(), core.Tanh())

        @tracer.symbolic
        def make_layer(a, b, c):
            d = add_lyr << (a, b)
            e = add_lyr << (d, c)
            f, g = relu_tanh_lyr << (d, e)
            return f, g

        layer = make_layer()  # pylint: disable=no-value-for-parameter
        # Seeded float32 inputs.  They match the float32 signature declared
        # below (the data was float64 while the signature claimed int32).
        # Computing the expected values from the same float32 data also keeps
        # float64-vs-float32 rounding from tripping the relative tolerance on
        # sums that land near zero.  This assumes the layer computes in
        # float32 (jax's default); confirm if x64 mode is ever enabled.
        rs = onp.random.RandomState(0)
        a = rs.uniform(-10, 10, size=(2, 10)).astype(onp.float32)
        b = rs.uniform(-10, 10, size=(2, 10)).astype(onp.float32)
        c = rs.uniform(-10, 10, size=(2, 10)).astype(onp.float32)
        input_sd = ShapeDtype((2, 10), onp.float32)
        input_signature = (input_sd, input_sd, input_sd)
        p, s = layer.new_weights_and_state(input_signature)
        res = layer((a, b, c), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter,not-callable
        result0 = onp.array(res[0])
        expected0 = onp.where(a + b > 0, a + b, 0.0)
        # A small atol guards elements whose true value is ~0, where rtol alone
        # is meaningless.
        onp.testing.assert_allclose(result0, expected0, rtol=1e-5, atol=1e-6)
        result1 = onp.array(res[1])
        expected1 = onp.tanh(a + b + c)
        onp.testing.assert_allclose(result1, expected1, rtol=1e-5, atol=1e-6)
Exemplo n.º 3
0
 def test_recombine(self):
     """recombine() on a two-eqn chain builds a Serial starting with those layers."""
     add_lyr = cb.Add()
     tanh_lyr = core.Tanh()
     chain = [
         tracer.ApplyEqn(add_lyr, ('a', 'b'), ('var1',)),
         tracer.ApplyEqn(tanh_lyr, ('var1',), ('var2',)),
     ]
     model = tracer.recombine(chain, ('a', 'b'), ('var2',))
     self.assertEqual(type(model), cb.Serial)
     # The leading sublayers must appear in the same order as the equations.
     for position, expected_layer in enumerate((add_lyr, tanh_lyr)):
         self.assertEqual(model.sublayers[position], expected_layer)
Exemplo n.º 4
0
 def test_eqns_eval_order3(self):
     """The eqn producing var1..var3 is sorted first for every input ordering."""
     dummy = core.Tanh()
     source_eqn = tracer.ApplyEqn(dummy, ('var0',), ('var1', 'var2', 'var3'))
     # Each of var1, var2, var3 feeds its own consumer writing var4, var5, var6.
     all_eqns = [source_eqn] + [
         tracer.ApplyEqn(dummy, ('var%d' % k,), ('var%d' % (k + 3),))
         for k in (1, 2, 3)
     ]
     final_outputs = ['var4', 'var5', 'var6']
     for ordering in itertools.permutations(all_eqns):
         first = tracer.evaluation_order_sort(ordering, final_outputs)[0]
         self.assertEqual(first, source_eqn)
Exemplo n.º 5
0
 def test_eqns_eval_order1(self):
     """evaluation_order_sort recovers a linear chain from every permutation.

     For each chain length n in 1..6, builds var0 -> var1 -> ... -> var<n>
     and checks that every ordering of the equations sorts back to the
     original dependency order.
     """
     # Exhaustive test of all permutations of linear chains up to 6 long
     # (range's upper bound is exclusive); the largest case is 6! = 720 orderings.
     dummy = core.Tanh()
     for n in range(1, 7):
         eqns = [
             tracer.ApplyEqn(dummy, ('var%d' % i, ), ('var%d' % (i + 1), ))
             for i in range(n)
         ]
         for permuted in itertools.permutations(eqns):
             ordered_eqns = tracer.evaluation_order_sort(
                 permuted, ['var%d' % n])
             self.assertEqual(ordered_eqns, eqns)
Exemplo n.º 6
0
    def test_symbolic_decorator2(self):
        """Symbolic layer that rebinds an input: the output should be tanh(a) + b + c."""
        add_lyr = cb.Add()
        tanh_lyr = core.Tanh()

        @tracer.symbolic
        def make_layer(a, b, c):
            a = tanh_lyr << a
            d = add_lyr << (a, b)
            e = add_lyr << (d, c)
            return e

        layer = make_layer()  # pylint: disable=no-value-for-parameter
        # Seeded int32 inputs: the data matches the int32 signature declared
        # below (randint's default is int64), and failures are reproducible.
        rs = onp.random.RandomState(0)
        a = rs.randint(0, 10, size=(2, 10), dtype=onp.int32)
        b = rs.randint(0, 10, size=(2, 10), dtype=onp.int32)
        c = rs.randint(0, 10, size=(2, 10), dtype=onp.int32)
        input_sd = ShapeDtype((2, 10), onp.int32)
        input_signature = (input_sd, input_sd, input_sd)
        p, s = layer.new_weights_and_state(input_signature)
        res = layer((a, b, c), weights=p, state=s, rng=jax.random.PRNGKey(0))  # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
        result = onp.array(res)
        expected = onp.tanh(a) + b + c
        # tanh yields floats. The expected value is float64, while the layer
        # presumably computes in float32 (jax's default; confirm). The
        # assert_allclose default rtol=1e-7 is at or below float32 precision,
        # so use 1e-5.
        onp.testing.assert_allclose(result, expected, rtol=1e-5)