コード例 #1
0
 def test_eqns_merge_outputs(self):
     """merge_output_tuples folds tuple-unpacking IndexEqns into the ApplyEqn.

     An ApplyEqn producing a single tuple 'var2' followed by IndexEqns that
     pull elements 0 and 1 into 'var0' / 'var1' is expected to collapse into
     one ApplyEqn whose outputs are ('var0', 'var1').
     """
     parallel = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
     apply_eqn = tracer.ApplyEqn(parallel, ('a', 'b'), ('var2', ))
     index_eqns = [
         tracer.IndexEqn(idx, 'var2', name)
         for idx, name in enumerate(('var0', 'var1'))
     ]
     merged = tracer.merge_output_tuples([apply_eqn] + index_eqns)
     expected = [tracer.ApplyEqn(parallel, ('a', 'b'), ('var0', 'var1'))]
     self.assertEqual(merged, expected)
コード例 #2
0
 def test_recombine(self):
     """recombine rebuilds a two-step ApplyEqn chain as a cb.Serial.

     The Serial's sublayers must be the original layers, in chain order.
     """
     first, second = cb.Add(), activation_fns.Tanh()
     chain = [
         tracer.ApplyEqn(first, ('a', 'b'), ('var1', )),
         tracer.ApplyEqn(second, ('var1', ), ('var2', )),
     ]
     model = tracer.recombine(chain, ('a', 'b'), ('var2', ))
     self.assertEqual(type(model), cb.Serial)
     # Check each sublayer position against the layer that produced it.
     for position, expected_layer in enumerate((first, second)):
         self.assertEqual(model.sublayers[position], expected_layer)
コード例 #3
0
 def test_eqns_eval_order3(self):
     """A fan-out producer must be sorted first for every input permutation.

     One ApplyEqn emits var1..var3; three independent ApplyEqns each consume
     one of those. Whatever order they are given in, evaluation_order_sort
     must put the producer at index 0.
     """
     dummy = activation_fns.Tanh()
     producer = tracer.ApplyEqn(dummy, ('var0', ), ('var1', 'var2', 'var3'))
     consumers = [
         tracer.ApplyEqn(dummy, (src, ), (dst, ))
         for src, dst in (('var1', 'var4'), ('var2', 'var5'), ('var3', 'var6'))
     ]
     eqns = [producer] + consumers
     outputs = ['var4', 'var5', 'var6']
     for ordering in itertools.permutations(eqns):
         first = tracer.evaluation_order_sort(ordering, outputs)[0]
         self.assertEqual(first, producer)
コード例 #4
0
 def test_eqns_eval_order2(self):
     """A join equation must be sorted after every equation it depends on.

     Three independent ApplyEqns produce var1, var3 and var5; a fourth
     consumes all three to produce var6. For every permutation of the input
     list, evaluation_order_sort must place the joining equation last.
     """
     # Build the dummy layer the same way as the sibling tests
     # (activation_fns.Tanh) rather than via core.Tanh.
     dummy = activation_fns.Tanh()
     eqns = [
         tracer.ApplyEqn(dummy, ('var0', ), ('var1', )),
         tracer.ApplyEqn(dummy, ('var2', ), ('var3', )),
         tracer.ApplyEqn(dummy, ('var4', ), ('var5', )),
         tracer.ApplyEqn(dummy, ('var1', 'var3', 'var5'), ('var6', )),
     ]
     for permuted in itertools.permutations(eqns):
         self.assertEqual(
             tracer.evaluation_order_sort(permuted, ['var6'])[-1], eqns[-1])
コード例 #5
0
 def test_apply_to_eqn(self):
     """Tracing `layer @ (a, b)` yields one ApplyEqn and a single output.

     traces_to_eqns is expected to name the fresh result 'var0' and to
     return it as the sole output name.
     """
     add = cb.Add()
     traced = add @ (tracer.Tracer('a'), tracer.Tracer('b'))
     eqns, outputs = tracer.traces_to_eqns(traced)
     self.assertEqual(eqns, [tracer.ApplyEqn(add, ('a', 'b'), ('var0', ))])
     self.assertEqual(outputs, ('var0', ))
コード例 #6
0
 def test_eqns_eval_order1(self):
     """evaluation_order_sort recovers a linear chain from any permutation.

     For each chain length n, builds var0 -> var1 -> ... -> var{n} from n
     ApplyEqns and checks that every permutation of those equations sorts
     back to the original chain order, given var{n} as the output.
     """
     # Exhaustive test of all permutations of linear chains of length 1
     # through 6 (range's upper bound is exclusive), i.e. at most 6! = 720
     # orderings per chain.
     dummy = activation_fns.Tanh()
     for n in range(1, 7):
         eqns = [
             tracer.ApplyEqn(dummy, ('var%d' % i, ), ('var%d' % (i + 1), ))
             for i in range(n)
         ]
         for permuted in itertools.permutations(eqns):
             ordered_eqns = tracer.evaluation_order_sort(
                 permuted, ['var%d' % n])
             self.assertEqual(ordered_eqns, eqns)
コード例 #7
0
 def test_apply_index_to_eqn(self):
     """Unpacking a multi-output application emits one IndexEqn per output.

     The layer application itself is expected to write a tuple to 'var2',
     followed by IndexEqns extracting elements 0 and 1 into 'var0' / 'var1',
     which are reported as the traced outputs.
     """
     parallel = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
     inputs = (tracer.Tracer('a'), tracer.Tracer('b'))
     first_out, second_out = parallel @ inputs
     eqns, outputs = tracer.traces_to_eqns((first_out, second_out))
     expected_eqns = [tracer.ApplyEqn(parallel, ('a', 'b'), ('var2', ))]
     expected_eqns.extend(
         tracer.IndexEqn(idx, 'var2', name)
         for idx, name in enumerate(('var0', 'var1')))
     self.assertEqual(eqns, expected_eqns)
     self.assertEqual(outputs, ('var0', 'var1'))