def _log_or_zero(value):
    """Return ``log(value)``, or ``LOG_ZERO`` when the logarithm is undefined.

    Only the domain error (``ValueError``) is caught; any other exception
    propagates, and the argument expression is evaluated before the ``try``
    so mistakes in it are never masked.
    NOTE(review): assumes ``log`` is ``math.log``, which raises ValueError
    for 0 — confirm against this file's imports.
    """
    try:
        return log(value)
    except ValueError:
        return LOG_ZERO


def _assert_log_der(expected, actual, decimal=15):
    """Assert that a node's log-derivative ``actual`` matches ``expected``.

    If ``expected`` is the ``LOG_ZERO`` sentinel, ``actual`` only has to be
    recognized as ``LOG_ZERO`` as well; otherwise both must agree to
    ``decimal`` decimal places.
    """
    if IS_LOG_ZERO(expected):
        assert IS_LOG_ZERO(actual)
    else:
        assert_almost_equal(expected, actual, decimal)


def _normalized_update(weights, vals, parent_der, eta):
    """Return the expected weights of one sum node after an update step.

    Each weight ``w`` paired with its child's value ``v`` becomes
    ``parent_der * v * eta + w``; the results are then divided by their sum.
    """
    updated = [parent_der * val * eta + weight
               for weight, val in zip(weights, vals)]
    total = sum(updated)
    return [weight / total for weight in updated]


def test_sum_layer_backprop():
    """Check log-derivative backprop and weight updates through a SumLayer.

    Five generic input nodes feed a layer of three sum nodes::

        sum1 -> node1 (0.3),  node2 (0.3),  node3 (0.4)
        sum2 -> node2 (0.15), node3 (0.15), node4 (0.7)
        sum3 -> node3 (0.4),  node4 (0.25), node5 (0.35)

    With parent derivatives (1, 1, 0), each child's expected derivative is the
    weighted sum of its parents' derivatives, compared in log space. The
    layer's weights are then updated and normalized, and the backprop check is
    repeated with new input values against the updated weights.
    """
    # input layer made of 5 generic nodes
    node1 = Node()
    node2 = Node()
    node3 = Node()
    node4 = Node()
    node5 = Node()
    children = [node1, node2, node3, node4, node5]

    # top layer made by 3 sum nodes
    sum1 = SumNode()
    sum2 = SumNode()
    sum3 = SumNode()

    # linking to input nodes; each sum node's weights add up to 1
    weight11, weight12, weight13 = 0.3, 0.3, 0.4
    sum1.add_child(node1, weight11)
    sum1.add_child(node2, weight12)
    sum1.add_child(node3, weight13)
    weight22, weight23, weight24 = 0.15, 0.15, 0.7
    sum2.add_child(node2, weight22)
    sum2.add_child(node3, weight23)
    sum2.add_child(node4, weight24)
    weight33, weight34, weight35 = 0.4, 0.25, 0.35
    sum3.add_child(node3, weight33)
    sum3.add_child(node4, weight34)
    sum3.add_child(node5, weight35)

    sum_layer = SumLayer([sum1, sum2, sum3])

    # setting input values
    val1, val2, val3, val4, val5 = 0.0, 0.5, 0.3, 1.0, 0.0
    for child, val in zip(children, (val1, val2, val3, val4, val5)):
        child.set_val(val)

    # evaluating
    sum_layer.eval()
    print('eval\'d layer:', sum_layer.node_values())

    # set the parent derivatives; sum3's zero derivative is LOG_ZERO in log space
    sum_der1, sum_der2, sum_der3 = 1.0, 1.0, 0.0
    sum1.log_der = log(sum_der1)
    sum2.log_der = log(sum_der2)
    sum3.log_der = LOG_ZERO

    # back prop layer wise
    sum_layer.backprop()

    # expected child derivative: sum over parents of weight * parent derivative
    exact_log_ders = [
        _log_or_zero(sum_der1 * weight11),
        _log_or_zero(sum_der1 * weight12 + sum_der2 * weight22),
        _log_or_zero(sum_der1 * weight13 + sum_der2 * weight23
                     + sum_der3 * weight33),
        _log_or_zero(sum_der2 * weight24 + sum_der3 * weight34),
        _log_or_zero(sum_der3 * weight35),
    ]
    # printing, just in case
    print('child log der', *(child.log_der for child in children))
    print('exact log der', *exact_log_ders)
    for exact, child in zip(exact_log_ders, children):
        _assert_log_der(exact, child.log_der)

    # updating weights
    # NOTE(review): eta must match the learning rate applied by
    # Spn.test_weight_update — confirm
    eta = 0.1
    sum_layer.update_weights(Spn.test_weight_update, 0)

    # checking for correctness: each weight grows by
    # parent_der * child_val * eta, then each sum node's weights are normalized
    weight_u11, weight_u12, weight_u13 = _normalized_update(
        (weight11, weight12, weight13), (val1, val2, val3), sum_der1, eta)
    weight_u22, weight_u23, weight_u24 = _normalized_update(
        (weight22, weight23, weight24), (val2, val3, val4), sum_der2, eta)
    weight_u33, weight_u34, weight_u35 = _normalized_update(
        (weight33, weight34, weight35), (val3, val4, val5), sum_der3, eta)
    expected_weights = [weight_u11, weight_u12, weight_u13,
                        weight_u22, weight_u23, weight_u24,
                        weight_u33, weight_u34, weight_u35]
    found_weights = [parent.weights[i]
                     for parent in (sum1, sum2, sum3) for i in range(3)]
    print('expected weights', *expected_weights)
    print('found weights', *found_weights)
    for expected, found in zip(expected_weights, found_weights):
        assert_almost_equal(expected, found, 10)

    #
    # resetting derivatives
    #
    for child in children:
        child.log_der = LOG_ZERO

    # setting new values as inputs
    for child, val in zip(children, (0.0, 0.0, 0.3, 1.0, 1.0)):
        child.set_val(val)

    # evaluating again
    sum_layer.eval()
    print('eval\'d layer:', sum_layer.node_values())

    # set the parent derivatives (same values as in the first round)
    sum1.log_der = log(sum_der1)
    sum2.log_der = log(sum_der2)
    sum3.log_der = LOG_ZERO

    # back prop layer wise
    sum_layer.backprop()

    # check for correctness, now against the updated weights
    exact_log_ders = [
        _log_or_zero(sum_der1 * weight_u11),
        _log_or_zero(sum_der1 * weight_u12 + sum_der2 * weight_u22),
        _log_or_zero(sum_der1 * weight_u13 + sum_der2 * weight_u23
                     + sum_der3 * weight_u33),
        _log_or_zero(sum_der2 * weight_u24 + sum_der3 * weight_u34),
        _log_or_zero(sum_der3 * weight_u35),
    ]
    # printing, just in case
    print('child log der', *(child.log_der for child in children))
    print('exact log der', *exact_log_ders)
    for exact, child in zip(exact_log_ders, children):
        _assert_log_der(exact, child.log_der)
def test_sum_node_backprop():
    """Backprop through two sum nodes that share the same two children.

    Each child's log-derivative must equal ``log`` of the sum over its parents
    of ``weight * parent_derivative``. Two rounds are run: parent derivatives
    (1.0, 1.0), then (0.5, 0.0), where the zero derivative is passed as
    ``LOG_ZERO``. Child derivatives are reset to ``LOG_ZERO`` between rounds.
    """
    # two generic children, both set to 1.
    children = []
    for value in (1., 1.):
        leaf = Node()
        leaf.set_val(value)
        children.append(leaf)
    child1, child2 = children

    # per-parent weights towards (child1, child2)
    weights1 = (0.8, 0.2)
    weights2 = (0.6, 0.4)  # the coparent

    sum_node1 = SumNode()
    for leaf, weight in zip(children, weights1):
        sum_node1.add_child(leaf, weight)
    sum_node2 = SumNode()
    for leaf, weight in zip(children, weights2):
        sum_node2.add_child(leaf, weight)

    sum_node1.eval()
    sum_node2.eval()

    # (derivative of sum_node1, derivative of sum_node2) for each round
    rounds = ((1.0, 1.0), (0.5, 0.0))
    for round_no, (der1, der2) in enumerate(rounds):
        if round_no > 0:
            # clear what the previous round accumulated on the children
            child1.log_der = LOG_ZERO
            child2.log_der = LOG_ZERO

        sum_node1.log_der = log(der1)
        sum_node1.backprop()
        # a zero derivative cannot go through log(): use the sentinel
        sum_node2.log_der = log(der2) if der2 else LOG_ZERO
        sum_node2.backprop()

        expected1 = log(weights1[0] * der1 + weights2[0] * der2)
        expected2 = log(weights1[1] * der1 + weights2[1] * der2)
        print('log ders 1:{lgd1} 2:{lgd2}'.format(lgd1=expected1,
                                                  lgd2=expected2))
        assert_almost_equal(expected1, child1.log_der, 15)
        assert_almost_equal(expected2, child2.log_der, 15)