def _flatten_best_effort(array):
    """Return a flattened version of *array* when possible, else *array* itself.

    First tries the object's own ``flatten()`` method (e.g. a numpy array),
    then falls back to ``flatten_list``. If both fail the input is returned
    unchanged, so the caller can still try to iterate over it.
    """
    try:
        return array.flatten()
    except Exception:
        # deliberate best-effort: not every input exposes ``flatten``
        pass
    try:
        return flatten_list(array)
    except Exception:
        # deliberate best-effort: leave the input as it is
        return array


def assert_log_array_almost_equal(array_1, array_2):
    """Assert that two collections of log values agree element by element.

    Both inputs are flattened on a best-effort basis and then compared
    pairwise. When an element of ``array_1`` equals ``LOG_ZERO`` the
    matching element of ``array_2`` only has to satisfy ``IS_LOG_ZERO``;
    otherwise the pair is compared with ``assert_almost_equal`` using
    ``PRECISION``.

    Note: elements are paired with ``zip``, so when the flattened inputs
    differ in length the extra trailing elements are not compared.

    Raises:
        AssertionError: if any compared pair does not match.
    """
    array_1 = _flatten_best_effort(array_1)
    array_2 = _flatten_best_effort(array_2)

    for elem_1, elem_2 in zip(array_1, array_2):
        if elem_1 == LOG_ZERO:
            # -2000 == -1000 since exp(-2000) == exp(-1000)
            assert IS_LOG_ZERO(elem_2) is True
        else:
            assert_almost_equal(elem_1, elem_2, PRECISION)
def test_product_layer_create_and_eval():
    """Evaluate a three-node ProductLayer and compare it to hand-made logs.

    Three generic nodes with values 0.8, 1.0 and 0.0 feed three product
    nodes, one per pair of inputs. The layer's node values are compared
    against the log of each pairwise product, with ``LOG_ZERO`` expected
    whenever the product is zero.
    """
    # generic input nodes and the values assigned to them
    leaf_vals = [0.8, 1., 0.]
    leaves = [Node(), Node(), Node()]
    for leaf, leaf_val in zip(leaves, leaf_vals):
        leaf.set_val(leaf_val)

    # one product node per pair of inputs (indices into ``leaves``)
    scopes = [(0, 1), (0, 2), (1, 2)]
    products = [ProductNode(), ProductNode(), ProductNode()]
    for product, scope in zip(products, scopes):
        for leaf_id in scope:
            product.add_child(leaves[leaf_id])

    # building and evaluating the layer
    product_layer = ProductLayer(products)
    product_layer.eval()

    layer_evals = product_layer.node_values()
    print('layer eval nodes')
    print(layer_evals)

    # expected log values computed by hand
    logvals = []
    for left, right in scopes:
        prodval = leaf_vals[left] * leaf_vals[right]
        if prodval > 0.:
            logvals.append(log(prodval))
        else:
            logvals.append(LOG_ZERO)
    print('log vals')
    print(logvals)

    for expected, found in zip(logvals, layer_evals):
        if expected == LOG_ZERO:
            # for zero log check this way for correctness
            assert IS_LOG_ZERO(found) is True
        else:
            assert_almost_equal(expected, found, PRECISION)
def test_categorical_input_layer():
    """Evaluate a two-node CategoricalInputLayer for many variable choices.

    For each pair of variable ids and each value of the first variable,
    an indicator node and a smoothed node are wrapped in a layer, the
    layer is evaluated on the module-level ``obs`` and its node values
    are compared with values computed by hand.
    """
    print('categorical input layer')
    # I could loop through alpha as well
    alpha = 0.1
    n_vars = len(vars)

    for var_id1 in range(n_vars):
        for var_id2 in range(n_vars):
            for var_val1 in range(vars[var_id1]):
                print('varid1, varid2, varval1',
                      var_id1, var_id2, var_val1)

                indicator = CategoricalIndicatorNode(var_id1, var_val1)
                smoothed = CategoricalSmoothedNode(var_id2,
                                                   vars[var_id2],
                                                   alpha,
                                                   freqs[var_id2])

                # creating the generic input layer and evaluating it
                # according to an observation
                input_layer = CategoricalInputLayer([indicator, smoothed])
                input_layer.eval(obs)

                layer_evals = input_layer.node_values()
                print('layer eval nodes')
                print(layer_evals)

                # computing evaluation by hand: the indicator fires when
                # the observed value matches or is marginalized
                observed = obs[var_id1]
                if var_val1 == observed or observed == MARG_IND:
                    expected_indicator = log(1)
                else:
                    expected_indicator = LOG_ZERO
                expected_smoothed = compute_smoothed_ll(obs[var_id2],
                                                        freqs[var_id2],
                                                        vars[var_id2],
                                                        alpha)
                logvals = [expected_indicator, expected_smoothed]
                print('log vals')
                print(logvals)

                for expected, found in zip(logvals, layer_evals):
                    if expected == LOG_ZERO:
                        # for zero log check this way for correctness
                        assert IS_LOG_ZERO(found) is True
                    else:
                        assert_almost_equal(expected, found, PRECISION)
def test_product_node_create_and_eval():
    """Evaluate a two-child ProductNode for inputs (1, 1), (0, 1) and (0, 0)."""
    first_child = Node()
    second_child = Node()

    def evaluate_with(first_val, second_val):
        """Set both child values, re-evaluate the product node, print its log value."""
        first_child.set_val(first_val)
        second_child.set_val(second_val)
        prod_node.eval()
        print(prod_node.log_val)

    # initial child values
    val1 = 1.
    val2 = 1.
    first_child.set_val(val1)
    second_child.set_val(val2)

    # create product node and add children
    prod_node = ProductNode()
    for child in (first_child, second_child):
        prod_node.add_child(child)
    assert len(prod_node.children) == 2
    print(prod_node)

    # evaluation with both children at 1
    prod_node.eval()
    print(prod_node.log_val)
    assert_almost_equal(prod_node.log_val, log(val1 * val2), places=15)

    # changing values 0,1 -> LOG_ZERO
    evaluate_with(0., 1.)
    assert_almost_equal(prod_node.log_val, LOG_ZERO, places=15)

    # changing values 0,0 -> still a log zero
    evaluate_with(0., 0.)
    # now testing with macro since -1000 + -1000 != -1000
    assert IS_LOG_ZERO(prod_node.log_val) is True
def test_sum_layer_backprop():
    """Check SumLayer backprop and weight update on 5 inputs and 3 sum nodes.

    The test runs two rounds:

    1. evaluate the layer, set the sum nodes' log derivatives, backprop and
       compare the input nodes' log derivatives with hand-computed values;
       then update the weights and compare them with a hand-computed,
       normalized gradient step;
    2. reset the input derivatives, change the inputs, and repeat the
       backprop check using the updated weights.
    """

    def safe_log(value):
        """Return log(value), or LOG_ZERO when value is not positive."""
        # explicit guard instead of catching whatever log() might raise
        return log(value) if value > 0.0 else LOG_ZERO

    def expected_child_log_ders(w11, w12, w13, w22, w23, w24, w33, w34, w35):
        """Log derivative of each input: parent derivatives weighted and summed."""
        return [
            safe_log(sum_der1 * w11),
            safe_log(sum_der1 * w12 + sum_der2 * w22),
            safe_log(sum_der1 * w13 + sum_der2 * w23 + sum_der3 * w33),
            safe_log(sum_der2 * w24 + sum_der3 * w34),
            safe_log(sum_der3 * w35),
        ]

    def check_child_log_ders(expected_log_ders):
        """Print and compare expected vs. found log derivatives of the inputs."""
        # printing, just in case
        print('child log der', *[node.log_der for node in nodes])
        print('exact log der', *expected_log_ders)
        for expected, node in zip(expected_log_ders, nodes):
            if IS_LOG_ZERO(expected):
                assert IS_LOG_ZERO(node.log_der)
            else:
                assert_almost_equal(expected, node.log_der, 15)

    def updated_weights(weights, sum_der, vals):
        """Apply the hand-computed step to one node's weights and normalize."""
        stepped = [sum_der * val * eta + weight
                   for weight, val in zip(weights, vals)]
        total = stepped[0] + stepped[1] + stepped[2]
        return [weight / total for weight in stepped]

    # input layer made of 5 generic nodes
    node1 = Node()
    node2 = Node()
    node3 = Node()
    node4 = Node()
    node5 = Node()
    nodes = [node1, node2, node3, node4, node5]

    # top layer made by 3 sum nodes
    sum1 = SumNode()
    sum2 = SumNode()
    sum3 = SumNode()

    # linking to input nodes
    weight11, weight12, weight13 = 0.3, 0.3, 0.4
    sum1.add_child(node1, weight11)
    sum1.add_child(node2, weight12)
    sum1.add_child(node3, weight13)

    weight22, weight23, weight24 = 0.15, 0.15, 0.7
    sum2.add_child(node2, weight22)
    sum2.add_child(node3, weight23)
    sum2.add_child(node4, weight24)

    weight33, weight34, weight35 = 0.4, 0.25, 0.35
    sum3.add_child(node3, weight33)
    sum3.add_child(node4, weight34)
    sum3.add_child(node5, weight35)

    sum_layer = SumLayer([sum1, sum2, sum3])

    # derivatives assigned to the sum nodes in both rounds
    sum_der1 = 1.0
    sum_der2 = 1.0
    sum_der3 = 0.0

    #
    # first round
    #
    # setting input values
    val1, val2, val3, val4, val5 = 0.0, 0.5, 0.3, 1.0, 0.0
    for node, val in zip(nodes, [val1, val2, val3, val4, val5]):
        node.set_val(val)

    # evaluating
    sum_layer.eval()
    print('eval\'d layer:', sum_layer.node_values())

    # set the parent derivatives
    sum1.log_der = log(sum_der1)
    sum2.log_der = log(sum_der2)
    sum3.log_der = LOG_ZERO

    # back prop layer wise
    sum_layer.backprop()

    # check for correctness
    check_child_log_ders(
        expected_child_log_ders(weight11, weight12, weight13,
                                weight22, weight23, weight24,
                                weight33, weight34, weight35))

    # updating weights
    # NOTE(review): eta presumably mirrors the step size applied by
    # Spn.test_weight_update - confirm against its definition
    eta = 0.1
    sum_layer.update_weights(Spn.test_weight_update, 0)

    # checking for correctness: step along the gradient, then normalize
    weight_u11, weight_u12, weight_u13 = updated_weights(
        [weight11, weight12, weight13], sum_der1, [val1, val2, val3])
    weight_u22, weight_u23, weight_u24 = updated_weights(
        [weight22, weight23, weight24], sum_der2, [val2, val3, val4])
    weight_u33, weight_u34, weight_u35 = updated_weights(
        [weight33, weight34, weight35], sum_der3, [val3, val4, val5])

    expected_weights = [weight_u11, weight_u12, weight_u13,
                        weight_u22, weight_u23, weight_u24,
                        weight_u33, weight_u34, weight_u35]
    found_weights = [sum_node.weights[child_id]
                     for sum_node in (sum1, sum2, sum3)
                     for child_id in range(3)]
    print('expected weights', *expected_weights)
    print('found weights', *found_weights)
    for expected, found in zip(expected_weights, found_weights):
        assert_almost_equal(expected, found, 10)

    #
    # second round
    #
    # resetting derivatives
    for node in nodes:
        node.log_der = LOG_ZERO

    # setting new values as inputs
    val1, val2, val3, val4, val5 = 0.0, 0.0, 0.3, 1.0, 1.0
    for node, val in zip(nodes, [val1, val2, val3, val4, val5]):
        node.set_val(val)

    # evaluating again
    sum_layer.eval()
    print('eval\'d layer:', sum_layer.node_values())

    # set the parent derivatives
    sum1.log_der = log(sum_der1)
    sum2.log_der = log(sum_der2)
    sum3.log_der = LOG_ZERO

    # back prop layer wise
    sum_layer.backprop()

    # check for correctness, this time against the updated weights
    check_child_log_ders(
        expected_child_log_ders(weight_u11, weight_u12, weight_u13,
                                weight_u22, weight_u23, weight_u24,
                                weight_u33, weight_u34, weight_u35))
def test_prod_layer_backprop():
    """Check ProductLayer backprop on 5 inputs and 3 product nodes.

    Each product node has three consecutive inputs as children. The test
    evaluates the layer, sets the product nodes' log derivatives, runs
    backprop and compares the inputs' log derivatives with hand-computed
    values. This is done twice, with different input values and with the
    input derivatives reset in between.
    """

    def safe_log(value):
        """Return log(value), or LOG_ZERO when value is not positive."""
        # explicit guard instead of catching whatever log() might raise
        return log(value) if value > 0.0 else LOG_ZERO

    def expected_child_log_ders():
        """Log derivative of each input from the current vals and parent ders.

        For every parent, the contribution is the parent's derivative times
        the product of the other children's values.
        """
        return [
            safe_log(prod_der1 * val2 * val3),
            safe_log(prod_der1 * val1 * val3 + prod_der2 * val3 * val4),
            safe_log(prod_der2 * val2 * val4 +
                     prod_der3 * val4 * val5 +
                     prod_der1 * val1 * val2),
            safe_log(prod_der2 * val2 * val3 + prod_der3 * val3 * val5),
            safe_log(prod_der3 * val3 * val4),
        ]

    def check_child_log_ders(expected_log_ders):
        """Print and compare expected vs. found log derivatives of the inputs."""
        # printing, just in case
        print('child log der', *[node.log_der for node in nodes])
        print('exact log der', *expected_log_ders)
        for expected, node in zip(expected_log_ders, nodes):
            if IS_LOG_ZERO(expected):
                assert IS_LOG_ZERO(node.log_der)
            else:
                assert_almost_equal(expected, node.log_der, 15)

    # input layer made of 5 generic nodes
    node1 = Node()
    node2 = Node()
    node3 = Node()
    node4 = Node()
    node5 = Node()
    nodes = [node1, node2, node3, node4, node5]

    input_layer = CategoricalInputLayer([node1, node2,
                                         node3, node4,
                                         node5])

    # top layer made by 3 prod nodes
    prod1 = ProductNode()
    prod2 = ProductNode()
    prod3 = ProductNode()

    # linking to input nodes
    for prod, children in ((prod1, (node1, node2, node3)),
                           (prod2, (node2, node3, node4)),
                           (prod3, (node3, node4, node5))):
        for child in children:
            prod.add_child(child)

    prod_layer = ProductLayer([prod1, prod2, prod3])

    # derivatives assigned to the product nodes in both rounds
    prod_der1 = 1.0
    prod_der2 = 1.0
    prod_der3 = 0.0

    #
    # first round
    #
    # setting input values
    val1, val2, val3, val4, val5 = 0.0, 0.5, 0.3, 1.0, 0.0
    for node, val in zip(nodes, [val1, val2, val3, val4, val5]):
        node.set_val(val)

    print('input', [node.log_val for node in input_layer.nodes()])

    # evaluating
    prod_layer.eval()
    print('eval\'d layer:', prod_layer.node_values())

    # set the parent derivatives
    prod1.log_der = log(prod_der1)
    prod2.log_der = log(prod_der2)
    prod3.log_der = LOG_ZERO

    # back prop layer wise
    prod_layer.backprop()

    # check for correctness
    check_child_log_ders(expected_child_log_ders())

    #
    # second round
    #
    # resetting derivatives
    for node in nodes:
        node.log_der = LOG_ZERO

    # setting new values as inputs
    val1, val2, val3, val4, val5 = 0.0, 0.0, 0.3, 1.0, 1.0
    for node, val in zip(nodes, [val1, val2, val3, val4, val5]):
        node.set_val(val)

    # evaluating again
    prod_layer.eval()
    print('eval\'d layer:', prod_layer.node_values())

    # set the parent derivatives
    prod1.log_der = log(prod_der1)
    prod2.log_der = log(prod_der2)
    prod3.log_der = LOG_ZERO

    # back prop layer wise
    prod_layer.backprop()

    # check for correctness
    check_child_log_ders(expected_child_log_ders())
def test_spn_mpe_eval_and_traversal():
    """Run MPE evaluation on a small four-layer Spn and check every layer.

    The network is: 5 generic inputs -> 3 sum nodes -> 2 product nodes ->
    2 root sum nodes. After ``spn.test_mpe_eval()`` the log values of the
    sum, product and root nodes are compared with hand-computed values in
    which each sum node takes the max over its weighted children. Finally
    the triples yielded by ``spn.mpe_traversal()`` are printed.
    """

    def log_or_zero(value):
        """Return log(value) unless value is numerically close to zero."""
        if numpy.isclose(value, 0):
            return LOG_ZERO
        return log(value)

    def check_log_val(expected, found):
        """Compare one expected log value with the one found on a node."""
        if IS_LOG_ZERO(expected):
            assert IS_LOG_ZERO(found)
        else:
            assert_almost_equal(expected, found)

    # create initial layer
    node1 = Node()
    node2 = Node()
    node3 = Node()
    node4 = Node()
    node5 = Node()
    inputs = [node1, node2, node3, node4, node5]

    input_layer = CategoricalInputLayer([node1, node2,
                                         node3, node4,
                                         node5])

    # top layer made by 3 sum nodes
    sum1 = SumNode()
    sum2 = SumNode()
    sum3 = SumNode()

    # linking to input nodes
    weight11, weight12, weight13 = 0.3, 0.3, 0.4
    sum1.add_child(node1, weight11)
    sum1.add_child(node2, weight12)
    sum1.add_child(node3, weight13)

    weight22, weight23, weight24 = 0.15, 0.15, 0.7
    sum2.add_child(node2, weight22)
    sum2.add_child(node3, weight23)
    sum2.add_child(node4, weight24)

    weight33, weight34, weight35 = 0.4, 0.25, 0.35
    sum3.add_child(node3, weight33)
    sum3.add_child(node4, weight34)
    sum3.add_child(node5, weight35)

    sum_layer = SumLayer([sum1, sum2, sum3])

    # another layer with two product nodes
    prod1 = ProductNode()
    prod2 = ProductNode()
    prod1.add_child(sum1)
    prod1.add_child(sum2)
    prod2.add_child(sum2)
    prod2.add_child(sum3)

    prod_layer = ProductLayer([prod1, prod2])

    # root layer, double sum
    root1 = SumNode()
    root2 = SumNode()

    weightr11, weightr12 = 0.5, 0.5
    root1.add_child(prod1, weightr11)
    root1.add_child(prod2, weightr12)

    weightr21, weightr22 = 0.9, 0.1
    root2.add_child(prod1, weightr21)
    root2.add_child(prod2, weightr22)

    root_layer = SumLayer([root1, root2])

    # create the spn
    spn = Spn(input_layer=input_layer,
              layers=[sum_layer, prod_layer, root_layer])

    separator = '==================='
    print(separator)
    print(spn)
    print(separator)

    # setting the input values
    val1, val2, val3, val4, val5 = 0.0, 0.5, 0.3, 1.0, 0.0
    for node, val in zip(inputs, [val1, val2, val3, val4, val5]):
        node.set_val(val)

    # evaluating the spn with MPE inference
    res = spn.test_mpe_eval()
    print('spn eval\'d', res)

    #
    # testing the max layer
    #
    max1 = max(val1 * weight11, val2 * weight12, val3 * weight13)
    max2 = max(val2 * weight22, val3 * weight23, val4 * weight24)
    max3 = max(val3 * weight33, val4 * weight34, val5 * weight35)
    log_max1 = log_or_zero(max1)
    log_max2 = log_or_zero(max2)
    log_max3 = log_or_zero(max3)

    print('expected max vals {0}, {1}, {2}'.format(log_max1,
                                                   log_max2,
                                                   log_max3))
    print('found max vals {0}, {1}, {2}'.format(sum1.log_val,
                                                sum2.log_val,
                                                sum3.log_val))
    check_log_val(log_max1, sum1.log_val)
    check_log_val(log_max2, sum2.log_val)
    check_log_val(log_max3, sum3.log_val)

    # product layer is assumed to be fine, but let's check
    # it anyways
    prod_val1 = max1 * max2
    prod_val2 = max2 * max3
    prod_log_val1 = log_max1 + log_max2
    prod_log_val2 = log_max2 + log_max3

    print('exp prod vals {0}, {1}'.format(prod_log_val1,
                                          prod_log_val2))
    print('rea prod vals {0}, {1}'.format(prod1.log_val,
                                          prod2.log_val))
    check_log_val(prod_log_val1, prod1.log_val)
    check_log_val(prod_log_val2, prod2.log_val)

    # root layer, again a sum layer
    root_val1 = max(prod_val1 * weightr11, prod_val2 * weightr12)
    root_val2 = max(prod_val1 * weightr21, prod_val2 * weightr22)
    root_log_val1 = log_or_zero(root_val1)
    root_log_val2 = log_or_zero(root_val2)

    print('exp root vals {0}, {1}'.format(root_log_val1,
                                          root_log_val2))
    print('found ro vals {0}, {1}'.format(root1.log_val,
                                          root2.log_val))
    check_log_val(root_log_val1, root1.log_val)
    check_log_val(root_log_val2, root2.log_val)

    # now we are traversing top down the net
    print('mpe traversing')
    for first, second, third in spn.mpe_traversal():
        print(first, second, third)
def test_spn_backprop():
    """Check Spn.backprop on a small four-layer network, layer by layer.

    The network is: 5 generic inputs -> 3 sum nodes -> 2 product nodes ->
    2 root sum nodes. After ``spn.test_eval()`` and ``spn.backprop()`` the
    log derivatives found on the root, product, sum and input nodes are
    compared, top down, with values computed by hand.
    """

    def safe_log(value):
        """Return log(value), or LOG_ZERO when value is not positive."""
        # same explicit guard at every level, rather than catching
        # whatever log() might raise
        return log(value) if value > 0.0 else LOG_ZERO

    def check_log_der(expected, found, *precision):
        """Compare one expected log derivative with the one found on a node."""
        if IS_LOG_ZERO(expected):
            assert IS_LOG_ZERO(found)
        else:
            assert_almost_equal(expected, found, *precision)

    # create initial layer
    node1 = Node()
    node2 = Node()
    node3 = Node()
    node4 = Node()
    node5 = Node()
    nodes = [node1, node2, node3, node4, node5]

    input_layer = CategoricalInputLayer([node1, node2,
                                         node3, node4,
                                         node5])

    # top layer made by 3 sum nodes
    sum1 = SumNode()
    sum2 = SumNode()
    sum3 = SumNode()

    # linking to input nodes
    weight11, weight12, weight13 = 0.3, 0.3, 0.4
    sum1.add_child(node1, weight11)
    sum1.add_child(node2, weight12)
    sum1.add_child(node3, weight13)

    weight22, weight23, weight24 = 0.15, 0.15, 0.7
    sum2.add_child(node2, weight22)
    sum2.add_child(node3, weight23)
    sum2.add_child(node4, weight24)

    weight33, weight34, weight35 = 0.4, 0.25, 0.35
    sum3.add_child(node3, weight33)
    sum3.add_child(node4, weight34)
    sum3.add_child(node5, weight35)

    sum_layer = SumLayer([sum1, sum2, sum3])

    # another layer with two product nodes
    prod1 = ProductNode()
    prod2 = ProductNode()
    prod1.add_child(sum1)
    prod1.add_child(sum2)
    prod2.add_child(sum2)
    prod2.add_child(sum3)

    prod_layer = ProductLayer([prod1, prod2])

    # root layer, double sum
    root1 = SumNode()
    root2 = SumNode()

    weightr11, weightr12 = 0.5, 0.5
    root1.add_child(prod1, weightr11)
    root1.add_child(prod2, weightr12)

    weightr21, weightr22 = 0.9, 0.1
    root2.add_child(prod1, weightr21)
    root2.add_child(prod2, weightr22)

    root_layer = SumLayer([root1, root2])

    # create the spn
    spn = Spn(input_layer=input_layer,
              layers=[sum_layer, prod_layer, root_layer])

    # setting the input values
    val1, val2, val3, val4, val5 = 0.0, 0.5, 0.3, 1.0, 0.0
    for node, val in zip(nodes, [val1, val2, val3, val4, val5]):
        node.set_val(val)

    # evaluating the spn
    res = spn.test_eval()
    print('spn eval\'d', res)

    # backprop
    spn.backprop()

    #
    # computing derivatives by hand
    #
    # topdown: root layer
    root_der = 1.0
    log_root_der = log(root_der)

    print('root ders', root1.log_der)
    assert_almost_equal(log_root_der, root1.log_der)
    assert_almost_equal(log_root_der, root2.log_der)

    # product layer: each product node is a child of both roots
    prod_der1 = (root_der * weightr11 + root_der * weightr21)
    prod_der2 = (root_der * weightr12 + root_der * weightr22)
    log_prod_der1 = safe_log(prod_der1)
    log_prod_der2 = safe_log(prod_der2)

    print('found prod ders', prod1.log_der, prod2.log_der)
    print('expect prod ders', log_prod_der1, log_prod_der2)
    check_log_der(log_prod_der1, prod1.log_der)
    check_log_der(log_prod_der2, prod2.log_der)

    # sum layer: a parent product's derivative times the value of the
    # sibling sum node under that product
    sum_val1 = weight11 * val1 + weight12 * val2 + weight13 * val3
    sum_val2 = weight22 * val2 + weight23 * val3 + weight24 * val4
    sum_val3 = weight33 * val3 + weight34 * val4 + weight35 * val5

    sum_der1 = prod_der1 * sum_val2
    sum_der2 = prod_der1 * sum_val1 + prod_der2 * sum_val3
    sum_der3 = prod_der2 * sum_val2
    log_sum_der1 = safe_log(sum_der1)
    log_sum_der2 = safe_log(sum_der2)
    log_sum_der3 = safe_log(sum_der3)

    print('expected sum ders', log_sum_der1,
          log_sum_der2,
          log_sum_der3)
    print('found sum ders', sum1.log_der,
          sum2.log_der,
          sum3.log_der)
    check_log_der(log_sum_der1, sum1.log_der)
    check_log_der(log_sum_der2, sum2.log_der)
    check_log_der(log_sum_der3, sum3.log_der)

    # final level, the first one
    expected_log_ders = [
        safe_log(sum_der1 * weight11),
        safe_log(sum_der1 * weight12 + sum_der2 * weight22),
        safe_log(sum_der1 * weight13 +
                 sum_der2 * weight23 +
                 sum_der3 * weight33),
        safe_log(sum_der2 * weight24 + sum_der3 * weight34),
        safe_log(sum_der3 * weight35),
    ]

    # printing, just in case
    print('child log der', *[node.log_der for node in nodes])
    print('exact log der', *expected_log_ders)

    for expected, node in zip(expected_log_ders, nodes):
        check_log_der(expected, node.log_der, 15)
def test_product_node_backprop():
    """Check ProductNode.backprop with two product nodes sharing children.

    ``prod_node1`` multiplies the first two children, ``prod_node2`` all
    three. For three different sets of child values the nodes are
    evaluated, given a log derivative and back-propagated; the children's
    accumulated log derivatives are then compared with values computed by
    hand. Before the second and third round the children's derivatives are
    reset to ``LOG_ZERO``.
    """

    def safe_log(value):
        """Return log(value), or LOG_ZERO when value is not positive."""
        # explicit guard instead of catching whatever log() might raise
        return log(value) if value > 0.0 else LOG_ZERO

    def expected_log_ders():
        """Hand-computed log derivatives from the current vals and parent ders."""
        return [
            safe_log(prod_node_der1 * val2 + prod_node_der2 * val2 * val3),
            safe_log(prod_node_der1 * val1 + prod_node_der2 * val1 * val3),
            safe_log(prod_node_der2 * val1 * val2),
        ]

    def check_children(expected):
        """Compare expected log derivatives with those found on the children."""
        for expected_log_der, child in zip(expected, children):
            if IS_LOG_ZERO(expected_log_der):
                assert IS_LOG_ZERO(child.log_der)
            else:
                assert_almost_equal(expected_log_der, child.log_der, 15)

    def backprop_both():
        """Assign the current derivatives to both product nodes and backprop."""
        prod_node1.log_der = log(prod_node_der1)
        prod_node1.backprop()
        prod_node2.log_der = log(prod_node_der2)
        prod_node2.backprop()

    log_ders_fmt = 'log ders 1:{lgd1} 2:{lgd2} 3:{lgd3}'

    # create child nodes
    child1 = Node()
    val1 = 1.
    child1.set_val(val1)

    child2 = Node()
    val2 = 1.
    child2.set_val(val2)

    child3 = Node()
    val3 = 0.0
    child3.set_val(val3)

    children = [child1, child2, child3]

    # create a product node and add children
    prod_node1 = ProductNode()
    prod_node1.add_child(child1)
    prod_node1.add_child(child2)

    # create a second node on all children
    prod_node2 = ProductNode()
    prod_node2.add_child(child1)
    prod_node2.add_child(child2)
    prod_node2.add_child(child3)

    #
    # first round: all expected derivatives are positive
    #
    prod_node1.eval()
    prod_node2.eval()

    # set der and backprop
    prod_node_der1 = 1.0
    prod_node_der2 = 1.0
    backprop_both()

    # check for correctness
    log_der1, log_der2, log_der3 = expected_log_ders()
    print(log_ders_fmt.format(lgd1=log_der1,
                              lgd2=log_der2,
                              lgd3=log_der3))
    check_children([log_der1, log_der2, log_der3])

    #
    # second and third round: different child values, smaller derivatives
    #
    for new_vals in ((0., 0., 1.), (0., 0.2, 1.)):
        # setting different values for children
        val1, val2, val3 = new_vals
        for child, val in zip(children, new_vals):
            child.set_val(val)

        # eval
        prod_node1.eval()
        prod_node2.eval()

        # derivatives accumulate on the children, so start from log zero
        for child in children:
            child.log_der = LOG_ZERO

        # set der and backprop
        prod_node_der1 = 0.5
        prod_node_der2 = 0.1
        backprop_both()

        # check for correctness
        log_der1, log_der2, log_der3 = expected_log_ders()
        print(log_ders_fmt.format(lgd1=log_der1,
                                  lgd2=log_der2,
                                  lgd3=log_der3))
        print(log_ders_fmt.format(lgd1=child1.log_der,
                                  lgd2=child2.log_der,
                                  lgd3=child3.log_der))
        check_children([log_der1, log_der2, log_der3])