def test_function_replace(self):
    """Replacing a Function node hands its children, parents and parameters to the new node."""
    middle = Parameter(3, name='par')
    old_selector = Function(
        lambda *args: args[1],
        name='selector',
        parameters=(
            Parameter('bla', name='first'),
            middle,
            Parameter('blup', name='third'),
        ),
    )
    # the product gives ``old_selector`` a parent, so there is a parent link to transfer
    product = Parameter(2, name='pre_factor') * old_selector

    # snapshot the graph neighbourhood before the replacement
    children_before, params_before, parents_before = (
        old_selector.get_children(),
        old_selector.parameters,
        old_selector.get_parents(),
    )

    new_selector = Function(lambda *args: args[2], name='selector_new')
    old_selector.replace(new_selector, other_children=False)

    children_after, parents_after, params_after = (
        new_selector.get_children(),
        new_selector.get_parents(),
        new_selector.parameters,
    )

    comparisons = (
        (children_before, children_after),
        (parents_before, parents_after),
        (params_before, params_after),
    )
    for before, after in comparisons:
        self.assertEqual(before, after)
    # the new selector picks args[2], i.e. the 'third' parameter
    self.assertEqual(new_selector.value, 'blup')
def test_function_value_lambda(self):
    """A Function wrapping a lambda evaluates to that lambda applied to its parameter values."""
    first, second = Parameter(3), Parameter(7)
    weighted = Function(
        lambda a, b: a * 10 + b,
        name='_lambda',
        parameters=(first, second),
    )
    actual = weighted.value
    self.assertEqual(actual, weighted.func(first.value, second.value))
def setUp(self):
    """Create fresh parameters, a summing Function, an Empty node and test accumulators."""
    self.par_a, self.par_b = Parameter(3), Parameter(7)
    self.func_sum_a_b = Function(
        TestNodes.sum_function,
        parameters=(self.par_a, self.par_b),
    )
    self.empty_par = Empty()
    # accumulators for callback-based tests
    self.counter = self.sum = 0
    self.array_content = [0.1, 2.3, 4.5, 6.7, 8.9]
def test_function_auto_parameters(self):
    """Parameters assigned after construction become the Function's inputs.

    Checks both the stored ``parameters`` list and that evaluation actually
    uses the assigned parameters (3 * 10 + 7 == 37).
    """
    def func(a, b):
        return a * 10 + b

    func_a_b = Function(func)
    par_a = Parameter(3)
    par_b = Parameter(7)
    func_a_b.parameters = [par_a, par_b]
    self.assertEqual(
        func_a_b.parameters,
        [par_a, par_b],
    )
    # storing the list is not enough: the assigned parameters must drive evaluation
    self.assertEqual(func_a_b.value, 37)
def test_parameter_replace(self):
    """Replacing a Parameter rewires the graph so the new node feeds the old parents."""
    original = Parameter(3, name='par')
    selector = Function(
        lambda *args: args[1],
        name='selector',
        parameters=(
            Parameter('bla', name='first'),
            original,
            Parameter('blup', name='third'),
        ),
    )
    doubled = Parameter(2, name='pre_factor') * selector

    before = {
        'children': original.get_children(),
        'parents': original.get_parents(),
    }

    replacement = Parameter(7, name='par_new')
    original.replace(replacement)
    replacement.value = 6
    # selector returns args[1], which is now ``replacement``: 2 * 6 == 12
    self.assertEqual(doubled.value, 12)

    after = {
        'children': replacement.get_children(),
        'parents': replacement.get_parents(),
    }
    self.assertEqual(before['children'], after['children'])
    self.assertEqual(before['parents'], after['parents'])
def test_function_auto_parameters_add_one_at_a_time(self):
    """Parameters added one by one keep insertion order and drive evaluation."""
    def func(a, b):
        return a * 10 + b

    func_a_b = Function(func)
    inputs = [Parameter(3), Parameter(7)]
    for single in inputs:
        func_a_b.add_parameter(single)

    self.assertEqual(func_a_b.parameters, inputs)
    # 3 * 10 + 7
    self.assertEqual(func_a_b.value, 37)
def test_function_parameters(self):
    """Constructor parameters are exposed both as children and as ``parameters``."""
    par_a, par_b = Parameter(3), Parameter(7)

    def func(a, b):
        return a * 10 + b

    func_a_b = Function(func, parameters=(par_a, par_b))
    expected = [par_a, par_b]
    self.assertEqual(func_a_b.get_children(), expected)
    self.assertEqual(func_a_b.parameters, expected)
def setUp(self):
    """Build two parameters, a Function summing them, and an Empty node for each test."""
    self.par_a, self.par_b = Parameter(3), Parameter(7)
    summed_inputs = (self.par_a, self.par_b)
    self.func_sum_a_b = Function(TestNodes.sum_function, parameters=summed_inputs)
    self.empty_par = Empty()
def test_function_value_update_frozen(self):
    """A frozen Function keeps its previous value until it is unfrozen."""
    par_a, par_b = Parameter(3), Parameter(7)

    def func(a, b):
        return a * 10 + b

    func_a_b = Function(func, parameters=(par_a, par_b))

    def assert_tracks_parameters():
        self.assertEqual(func_a_b.value, func(par_a.value, par_b.value))

    assert_tracks_parameters()

    # while frozen, changing an input must not change the value
    func_a_b.freeze()
    par_a.value = -3
    self.assertEqual(func_a_b.value, func(3, 7))

    # once unfrozen, the value follows the current inputs again
    func_a_b.unfreeze()
    assert_tracks_parameters()
def setUp(self):
    """Build a small node graph hung below a RootNode for each test."""
    self.root = RootNode()
    self.empty = Empty(name="empty")
    self.par_a, self.par_b, self.par_c = (
        Parameter(2, name="par_a"),
        Parameter(3, name="par_b"),
        Parameter(4, name="par_c"),
    )
    self.func_1 = self.par_a + self.par_b
    self.func_2 = self.par_a * self.par_c
    self.func_3 = Function(lambda a, b: a + b)
    self.alias = Alias(self.par_a, name="alias")
    # make par_b the child of par_a
    self.par_a.set_children([self.par_b])
    top_level = [self.empty, self.func_1, self.func_2, self.func_3, self.alias]
    self.root.set_children(top_level)
def test_nodes_parametric_constructor(self):
    """Each parametric node type can be built from minimal arguments and rendered with str()."""
    def _render(node):
        str(node)
        return node

    _render(Parameter(None))
    _render(Alias(Parameter(None)))
    identity = _render(Function(lambda x: x, name='_lambda'))
    _render(Fallback((identity, )))
    _render(Tuple([]))
def test_function_value(self):
    """A Function recomputes from its parameters and rejects direct value assignment."""
    par_a, par_b = Parameter(3), Parameter(7)

    def func(a, b):
        return a * 10 + b

    func_a_b = Function(func, parameters=(par_a, par_b))

    def check():
        self.assertEqual(func_a_b.value, func(par_a.value, par_b.value))

    check()
    # changing an input propagates to the function value
    par_a.value = -3
    check()
    # function values are derived, so assigning one directly must fail
    with self.assertRaises(NodeException):
        func_a_b.value = 5
def test_parameter_replace_literal(self):
    """Replacing a Parameter with a plain literal feeds that literal to downstream nodes."""
    target = Parameter(3, name='par')
    candidates = (
        Parameter('bla', name='first'),
        target,
        Parameter('blup', name='third'),
    )
    selector = Function(lambda *args: args[1], name='selector', parameters=candidates)
    scaled = Parameter(2, name='pre_factor') * selector

    target.replace(7)
    # selector returns its middle argument, now 7: 2 * 7 == 14
    self.assertEqual(scaled.value, 14)
class TestNodes(unittest.TestCase):
    """Unit tests for the node types.

    Covers NodeBase graph operations (parents/children, replace), RootNode,
    Empty, Parameter, Alias, Function, Fallback, Tuple and Array.
    """

    # Node types constructed without arguments; NodeBase is abstract and skipped.
    DEFAULT_CONSTRUCTIBLE_NODE_TYPES = (NodeBase, Empty, RootNode)
    # Node types constructed with arguments (see test_nodes_parametric_constructor).
    PARAMETRIC_NODE_TYPES = (Parameter, Alias, Function, Fallback, Tuple)

    @staticmethod
    def sum_function(a, b):
        """Return ``a + b``; used as the callable of ``self.func_sum_a_b``."""
        return a + b

    def setUp(self):
        self.par_a = Parameter(3)
        self.par_b = Parameter(7)
        self.func_sum_a_b = Function(TestNodes.sum_function, parameters=(self.par_a, self.par_b))
        self.empty_par = Empty()
        # accumulators mutated by the callbacks registered in test_callback
        self.counter = 0
        self.sum = 0
        self.array_content = [0.1, 2.3, 4.5, 6.7, 8.9]

    # -- NodeBase

    def test_nodes_default_constructor(self):
        for _node_type in self.DEFAULT_CONSTRUCTIBLE_NODE_TYPES:
            # NodeBase constructor cannot be called because NodeBase has an abstract method.
            if _node_type is NodeBase:
                continue
            _node = _node_type()
            str(_node)

    def test_nodes_parametric_constructor(self):
        p = Parameter(None)
        str(p)
        a = Alias(Parameter(None))
        str(a)
        f = Function(lambda x: x, name='_lambda')
        str(f)
        fb = Fallback((f, ))
        str(fb)
        t = Tuple([])
        str(t)

    def test_nodes_equality(self):
        p_1 = Parameter(None)
        p_2 = Parameter(None)
        self.assertEqual(p_1, p_1)
        # distinct nodes are unequal even when they hold the same value
        self.assertNotEqual(p_1, p_2)

    def test_node_get_children_get_parents(self):
        par_1 = Parameter(3, name='par_1')
        par_2 = Parameter(2, name='par_2')
        # binary operations create new nodes that become parents of their operands
        par_lc_1 = par_1 + par_2
        par_lc_2 = par_1 - par_2
        for _par in par_1, par_2:
            self.assertEqual(_par.get_parents(), [par_lc_1, par_lc_2])
        for _par in par_lc_1, par_lc_2:
            self.assertEqual(_par.get_children(), [par_1, par_2])

    def test_node_iter_children_iter_parents(self):
        par_1 = Parameter(3, name='par_1')
        par_2 = Parameter(2, name='par_2')
        par_lc_1 = par_1 + par_2
        par_lc_2 = par_1 - par_2
        for _par in par_1, par_2:
            self.assertEqual(list(_par.iter_parents()), [par_lc_1, par_lc_2])
        for _par in par_lc_1, par_lc_2:
            self.assertEqual(list(_par.iter_children()), [par_1, par_2])

    def test_node_add_child(self):
        par = Parameter(3, name='par')
        # literals are accepted as children; their values are read back below
        par.add_child(1)
        par.add_child(2)
        self.assertEqual([p.value for p in par.iter_children()], [1, 2])

    def test_node_add_parent(self):
        par = Parameter(3, name='par')
        par_2 = Parameter(2, name='par_2')
        with self.assertRaises(TypeError):
            par.add_parent("notanode")
        # Manually adding parents is not allowed:
        with self.assertRaises(NodeException):
            par.add_parent(par_2)
        # the parent link is established from the parent's side instead
        par_2.add_child(par)
        self.assertEqual(par.get_parents(), [par_2])

    def test_node_set_children(self):
        par = Parameter(3, name='par')
        child_1 = Parameter(1, name='child_1')
        child_2 = Parameter(2, name='child_2')
        par.set_children([child_1, child_2])
        self.assertEqual(par.get_children(), [child_1, child_2])

    def test_node_remove_child(self):
        par = Parameter(3, name='par')
        child_1 = Parameter(1, name='child_1')
        child_2 = Parameter(2, name='child_2')
        par.set_children([child_1, child_2])
        par.remove_child(child_1)
        self.assertEqual(par.get_children(), [child_2])
        with self.assertRaises(TypeError):
            par.remove_child("notanode")

    def test_node_remove_parent(self):
        par = Parameter(3, name='par')
        parent_1 = Parameter(1, name='parent_1')
        parent_2 = Parameter(2, name='parent_2')
        parent_1.add_child(par)
        parent_2.add_child(par)
        # removing a parent from the child's side is not allowed ...
        with self.assertRaises(NodeException):
            par.remove_parent(parent_1)
        self.assertEqual(par.get_parents(), [parent_1, parent_2])
        # ... it has to be done by the parent removing the child
        parent_1.remove_child(par)
        self.assertEqual(par.get_parents(), [parent_2])
        with self.assertRaises(TypeError):
            par.remove_parent("notanode")

    def test_parameter_replace(self):
        par = Parameter(3, name='par')
        selector = Function(
            lambda *args: args[1],
            name='selector',
            parameters=(
                Parameter('bla', name='first'),
                par,
                Parameter('blup', name='third'),
            ),
        )
        context = Parameter(2, name='pre_factor') * selector

        # store parents and children for comparison
        _children_before = par.get_children()
        _parents_before = par.get_parents()

        # replace node
        par_new = Parameter(7, name='par_new')
        par.replace(par_new)
        par_new.value = 6
        # selector returns args[1], now par_new: 2 * 6 == 12
        self.assertEqual(context.value, 12)

        # retrieve new parents and children for comparison
        _children_after = par_new.get_children()
        _parents_after = par_new.get_parents()

        self.assertEqual(_children_before, _children_after)
        self.assertEqual(_parents_before, _parents_after)

    def test_parameter_replace_literal(self):
        par = Parameter(3, name='par')
        selector = Function(
            lambda *args: args[1],
            name='selector',
            parameters=(
                Parameter('bla', name='first'),
                par,
                Parameter('blup', name='third'),
            ),
        )
        context = Parameter(2, name='pre_factor') * selector

        # replace node with a plain literal
        par.replace(7)
        self.assertEqual(context.value, 14)

    def test_function_replace(self):
        par = Parameter(3, name='par')
        selector = Function(
            lambda *args: args[1],
            name='selector',
            parameters=(
                Parameter('bla', name='first'),
                par,
                Parameter('blup', name='third'),
            ),
        )
        # not read again, but gives ``selector`` a parent whose link must be transferred
        context = Parameter(2, name='pre_factor') * selector

        # store parents and children for comparison
        _children_before = selector.get_children()
        _parameters_before = selector.parameters
        _parents_before = selector.get_parents()

        # replace node
        selector_new = Function(lambda *args: args[2], name='selector_new')
        selector.replace(selector_new, other_children=False)

        # retrieve new parents and children for comparison
        _children_after = selector_new.get_children()
        _parents_after = selector_new.get_parents()
        _parameters_after = selector_new.parameters

        self.assertEqual(_children_before, _children_after)
        self.assertEqual(_parents_before, _parents_after)
        self.assertEqual(_parameters_before, _parameters_after)
        # the new selector returns args[2], the 'third' parameter
        self.assertEqual(selector_new.value, 'blup')

    def test_callback(self):
        def increment_counter():
            self.counter += 1

        def update_sum(summand):
            self.sum += summand

        self.par_a.register_callback(lambda: 42)
        self.par_a.register_callback(increment_counter)
        # callback on the dependent Function, triggered by changes to par_a
        self.func_sum_a_b.register_callback(update_sum, args=[10])

        self.par_a.value = 5
        self.assertEqual(self.counter, 1)
        self.assertEqual(self.sum, 10)

        self.par_a.value = 6
        self.assertEqual(self.counter, 2)
        self.assertEqual(self.sum, 20)

    def test_replace_child(self):
        par_a = Parameter(4)
        par_b = Parameter(5)
        par_c = Parameter(6)
        par_a.add_child(par_b)
        self.assertEqual(par_a.get_children(), [par_b])
        # par_c is not a child yet, so it cannot be replaced
        with self.assertRaises(NodeException):
            par_a.replace_child(current_child=par_c, new_child=par_b)
        par_a.replace_child(current_child=par_b, new_child=par_c)
        # par_b is no longer a child after the replacement
        with self.assertRaises(NodeException):
            par_a.replace_child(current_child=par_b, new_child=par_c)
        self.assertEqual(par_a.get_children(), [par_c])
        # a literal replacement is wrapped in a new Parameter node
        par_a.replace_child(current_child=par_c, new_child=7)
        self.assertNotEqual(par_a.get_children(), [par_b])
        self.assertNotEqual(par_a.get_children(), [par_c])
        self.assertIs(type(par_a.get_children()[0]), Parameter)

    # -- RootNode

    def test_root_node_replace(self):
        r = RootNode()
        with self.assertRaises(TypeError):
            r.replace(7)

    def test_root_node_parents(self):
        r = RootNode()
        with self.assertRaises(TypeError):
            r.add_parent(2)
        self.assertEqual(list(r.iter_parents()), [])
        self.assertEqual(r.get_parents(), [])

    # -- Empty

    def test_empty_value(self):
        e = Empty()
        with self.assertRaises(NodeException):
            _ = e.value
        with self.assertRaises(NodeException):
            e.value = 33

    # -- Parameter

    def test_parameter_get_value(self):
        par = Parameter(3)
        self.assertEqual(par.value, 3)

    def test_parameter_set_value(self):
        par = Parameter(3)
        par.value = 7
        self.assertEqual(par.value, 7)

    def test_parameter_name(self):
        par = Parameter(3, name='bla')
        self.assertEqual(par.name, 'bla')

    def test_parameter_invalid_name_raise(self):
        # names starting with a digit are rejected
        with self.assertRaises(NodeException):
            Parameter(3, name='2bla')

    def test_parameter_reserved_name_raise(self):
        for _rn in NodeBase.RESERVED_PARAMETER_NAMES:
            with self.assertRaises(NodeException):
                Parameter(3, name=_rn)

    def test_tuple_iter_raise(self):
        # despite the name, this checks that a plain Parameter is not iterable
        par = Parameter(3)
        with self.assertRaises(TypeError):
            iter(par)

    def test_parameter_binary_operation(self):
        par_a = Parameter(3)
        par_b = Parameter(7)
        # only the operator callable of each _OPERATORS entry is needed here
        for _op, _ in _OPERATORS.values():
            par_expr = _op(par_a, par_b)
            self.assertEqual(par_expr.value, _op(par_a.value, par_b.value))

    def test_parameter_binary_operation_with_literal(self):
        par_a = Parameter(3)
        for _op, _ in _OPERATORS.values():
            par_expr = _op(par_a, 7)
            self.assertEqual(par_expr.value, _op(par_a.value, 7))
            # reversed order
            par_expr = _op(7, par_a)
            self.assertEqual(par_expr.value, _op(7, par_a.value))

    def test_parameter_unary_operation(self):
        par = Parameter(3)
        self.assertEqual((+par).value, par.value)
        self.assertEqual((-par).value, -(par.value))
        self.assertEqual((~par).value, ~(par.value))

    # -- Alias

    def test_alias_value(self):
        par = Parameter(3)
        alias = Alias(par, name='alias')
        self.assertEqual(alias.value, par.value)
        # aliases are read-only views of their reference
        with self.assertRaises(NodeException):
            alias.value = 5

    def test_multiple_alias_value(self):
        par = Parameter(3)
        alias = Alias(Alias(Alias(par)))
        self.assertEqual(alias.value, par.value)

    def test_alias_ref(self):
        par = Parameter(3)
        alias = Alias(par, name='alias')
        self.assertEqual(alias.ref, par)

    # -- Function

    def test_function_value(self):
        par_a = Parameter(3)
        par_b = Parameter(7)

        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func, parameters=(par_a, par_b))
        self.assertEqual(
            func_a_b.value,
            func(par_a.value, par_b.value),
        )

        # test value update
        par_a.value = -3
        self.assertEqual(
            func_a_b.value,
            func(par_a.value, par_b.value),
        )

        # function values are derived and cannot be assigned
        with self.assertRaises(NodeException):
            func_a_b.value = 5

    def test_function_value_update_frozen(self):
        par_a = Parameter(3)
        par_b = Parameter(7)

        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func, parameters=(par_a, par_b))
        self.assertEqual(
            func_a_b.value,
            func(par_a.value, par_b.value),
        )

        # while frozen, the value must not follow the changed input
        func_a_b.freeze()
        par_a.value = -3
        self.assertEqual(
            func_a_b.value,
            func(3, 7),
        )

        # after unfreezing, the value tracks the inputs again
        func_a_b.unfreeze()
        self.assertEqual(
            func_a_b.value,
            func(par_a.value, par_b.value),
        )

    def test_function_value_lambda(self):
        par_a = Parameter(3)
        par_b = Parameter(7)
        func_a_b = Function(lambda a, b: a * 10 + b, name='_lambda', parameters=(par_a, par_b))
        self.assertEqual(
            func_a_b.value,
            func_a_b.func(par_a.value, par_b.value),
        )

    def test_function_parameters(self):
        par_a = Parameter(3)
        par_b = Parameter(7)

        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func, parameters=(par_a, par_b))
        self.assertEqual(
            func_a_b.get_children(),
            [par_a, par_b],
        )
        self.assertEqual(
            func_a_b.parameters,
            [par_a, par_b],
        )

    def test_function_auto_parameters(self):
        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func)
        par_a = Parameter(3)
        par_b = Parameter(7)
        func_a_b.parameters = [par_a, par_b]
        self.assertEqual(
            func_a_b.parameters,
            [par_a, par_b],
        )
        self.assertEqual(func_a_b.value, 37)

    def test_function_auto_parameters_add_one_at_a_time(self):
        def func(a, b):
            return a * 10 + b

        func_a_b = Function(func)
        par_a = Parameter(3)
        par_b = Parameter(7)
        func_a_b.add_parameter(par_a)
        func_a_b.add_parameter(par_b)
        self.assertEqual(
            func_a_b.parameters,
            [par_a, par_b],
        )
        self.assertEqual(func_a_b.value, 37)

    # -- Fallback

    def test_fallback(self):
        par_a = Parameter(14)
        par_b = Parameter(2)
        div_a_b = Fallback((par_a / par_b, 'DIV0'), exception_type=ZeroDivisionError)

        # no exception yet
        self.assertEqual(
            div_a_b.value,
            par_a.value / par_b.value,
        )

        # make denominator zero
        par_b.value = 0

        # exception matches type -> fallback
        self.assertEqual(
            div_a_b.value,
            'DIV0',
        )

        # exception does not match type -> raise
        with self.assertRaises(ZeroDivisionError):
            Fallback((par_a / par_b, 'DIV0'), exception_type=TypeError).value

    def test_fallback_no_good_alternative(self):
        par_a = Parameter(14)
        # constructing the node does not evaluate it, so nothing is raised here
        div_a_b = Fallback((par_a / 0, 2 * par_a / 0), exception_type=ZeroDivisionError)

        # every alternative raises ZeroDivisionError -> FallbackError on access
        with self.assertRaises(FallbackError):
            _ = div_a_b.value

    # -- Tuple

    def test_tuple_value(self):
        tuple_a = Tuple((14, 2))
        self.assertIs(type(tuple_a.value), tuple)
        self.assertEqual(tuple_a.value, (14, 2))

    def test_tuple_len(self):
        tuple_a = Tuple((14, 2))
        self.assertEqual(len(tuple_a), 2)

    def test_tuple_getitem(self):
        tuple_a = Tuple((14, 2))
        self.assertEqual(tuple_a.value, (14, 2))
        self.assertEqual(tuple_a[0].value, 14)
        self.assertEqual(tuple_a[1].value, 2)

    def test_tuple_setitem(self):
        tuple_a = Tuple((14, 2))
        # both node and literal assignment are supported
        tuple_a[0] = Parameter(15)
        tuple_a[1] = 1
        self.assertEqual(tuple_a.value, (15, 1))
        self.assertEqual(tuple_a[0].value, 15)
        self.assertEqual(tuple_a[1].value, 1)

    def test_tuple_nodes(self):
        tuple_a = Tuple((14, 2))
        tuple_a.nodes = (9, 4)
        self.assertEqual(tuple_a.value, (9, 4))

    def test_tuple_iter(self):
        tuple_a = Tuple((14, 2))
        self.assertEqual(tuple_a.nodes, list(tuple_a))

    def test_tuple_iter_values(self):
        tuple_a = Tuple((14, 2))
        self.assertEqual(tuple_a.value, tuple(tuple_a.iter_values()))

    # -- Array

    def test_array_value(self):
        array_a = Array(self.array_content)
        self.assertIs(type(array_a.value), np.ndarray)
        self.assertTrue(np.all(array_a.value == self.array_content))

    def test_array_getitem(self):
        array_a = Array(self.array_content)
        for i, expected_value_i in enumerate(self.array_content):
            self.assertEqual(array_a[i].value, expected_value_i)

    def test_array_setitem(self):
        array_a = Array(self.array_content)
        # a literal assigned to an element is wrapped in a Parameter node
        array_a[3] = 0
        self.assertIs(type(array_a[3]), Parameter)
        self.assertTrue(np.all(array_a.value == [0.1, 2.3, 4.5, 0.0, 8.9]))

    def test_array_iter_values(self):
        array_a = Array(self.array_content)
        self.assertTrue(
            np.all(array_a.value == np.array(list(array_a.iter_values()))))
def test_nodes_parametric_constructor(self):
    """Each parametric node type can be constructed from minimal arguments."""
    Parameter(None)
    Alias(Parameter(None))
    # the Fallback below needs an existing node as its single alternative
    identity = Function(lambda x: x, name='_lambda')
    Fallback((identity, ))
    Tuple([])