def test_eqns_merge_outputs(self):
    """merge_output_tuples folds a tuple output and its IndexEqns into one ApplyEqn."""
    layer = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
    # `layer` writes a single tuple-valued var2, which two IndexEqns then
    # unpack into var0 and var1.
    traced = [
        tracer.ApplyEqn(layer, ('a', 'b'), ('var2',)),
        tracer.IndexEqn(0, 'var2', 'var0'),
        tracer.IndexEqn(1, 'var2', 'var1'),
    ]
    merged = tracer.merge_output_tuples(traced)
    # After merging, the layer writes var0 and var1 directly.
    expected = [tracer.ApplyEqn(layer, ('a', 'b'), ('var0', 'var1'))]
    self.assertEqual(merged, expected)
def test_recombine(self):
    """recombine rebuilds a Serial whose sublayers follow the equation order."""
    add_layer = cb.Add()
    tanh_layer = activation_fns.Tanh()
    # Chain: (a, b) --Add--> var1 --Tanh--> var2
    chain = [
        tracer.ApplyEqn(add_layer, ('a', 'b'), ('var1',)),
        tracer.ApplyEqn(tanh_layer, ('var1',), ('var2',)),
    ]
    model = tracer.recombine(chain, ('a', 'b'), ('var2',))
    self.assertEqual(type(model), cb.Serial)
    for position, layer in enumerate((add_layer, tanh_layer)):
        self.assertEqual(model.sublayers[position], layer)
def test_eqns_eval_order3(self):
    """The equation fanning var0 out into var1..var3 must always sort first."""
    tanh = activation_fns.Tanh()
    fan_out = tracer.ApplyEqn(tanh, ('var0',), ('var1', 'var2', 'var3'))
    # One consumer per fanned-out variable, each producing a final output.
    consumers = [
        tracer.ApplyEqn(tanh, (src,), (dst,))
        for src, dst in (('var1', 'var4'), ('var2', 'var5'), ('var3', 'var6'))
    ]
    eqns = [fan_out] + consumers
    outputs = ['var4', 'var5', 'var6']
    for ordering in itertools.permutations(eqns):
        ordered = tracer.evaluation_order_sort(ordering, outputs)
        self.assertEqual(ordered[0], fan_out)
def test_eqns_eval_order2(self):
    """The equation fanning var1, var3 and var5 in must always sort last."""
    # Tanh comes from activation_fns, consistent with the other tracer tests
    # in this file (previously this one alone used core.Tanh).
    dummy = activation_fns.Tanh()
    eqns = [
        tracer.ApplyEqn(dummy, ('var0',), ('var1',)),
        tracer.ApplyEqn(dummy, ('var2',), ('var3',)),
        tracer.ApplyEqn(dummy, ('var4',), ('var5',)),
        # Depends on all three equations above, so it must be evaluated last.
        tracer.ApplyEqn(dummy, ('var1', 'var3', 'var5'), ('var6',)),
    ]
    for permuted in itertools.permutations(eqns):
        self.assertEqual(
            tracer.evaluation_order_sort(permuted, ['var6'])[-1], eqns[-1])
def test_apply_to_eqn(self):
    """Applying a layer to tracers with `@` yields one ApplyEqn and its output."""
    add = cb.Add()
    lhs, rhs = tracer.Tracer('a'), tracer.Tracer('b')
    traced = add @ (lhs, rhs)
    eqns, outputs = tracer.traces_to_eqns(traced)
    self.assertEqual(eqns, [tracer.ApplyEqn(add, ('a', 'b'), ('var0',))])
    self.assertEqual(outputs, ('var0',))
def test_eqns_eval_order1(self):
    """Every permutation of a linear equation chain sorts back into chain order."""
    # Exhaustive test of all permutations of linear chains of 1 to 6
    # equations (range(1, 7)); the largest size has 6! = 720 orderings.
    dummy = activation_fns.Tanh()
    for n in range(1, 7):
        # Chain var0 -> var1 -> ... -> var{n}.
        eqns = [
            tracer.ApplyEqn(dummy, ('var%d' % i,), ('var%d' % (i + 1),))
            for i in range(n)
        ]
        for permuted in itertools.permutations(eqns):
            ordered_eqns = tracer.evaluation_order_sort(
                permuted, ['var%d' % n])
            self.assertEqual(ordered_eqns, eqns)
def test_apply_index_to_eqn(self):
    """Unpacking a multi-output layer application yields ApplyEqn plus IndexEqns."""
    parallel = cb.Parallel(activation_fns.Tanh(), activation_fns.Tanh())
    x, y = tracer.Tracer('a'), tracer.Tracer('b')
    first, second = parallel @ (x, y)
    eqns, outputs = tracer.traces_to_eqns((first, second))
    # The layer's tuple result is named var2; var0 and var1 index into it.
    expected_eqns = [
        tracer.ApplyEqn(parallel, ('a', 'b'), ('var2',)),
        tracer.IndexEqn(0, 'var2', 'var0'),
        tracer.IndexEqn(1, 'var2', 'var1'),
    ]
    self.assertEqual(eqns, expected_eqns)
    self.assertEqual(outputs, ('var0', 'var1'))