def _log_or_log_zero(value):
    """Return ``log(value)``, or ``LOG_ZERO`` when the logarithm fails.

    Only numeric failures are converted into ``LOG_ZERO`` (for instance the
    ``ValueError`` that a ``math.log``-style function raises for a zero
    argument). Any other exception signals a real bug in the test and is
    left to propagate instead of being silently swallowed.
    """
    try:
        return log(value)
    except (ValueError, ArithmeticError):
        return LOG_ZERO


def _assert_log_der_matches(expected_log_der, node):
    """Assert that ``node.log_der`` agrees with ``expected_log_der``.

    A log-zero expectation cannot be compared numerically, so in that case
    the node derivative is only required to be log-zero as well.
    """
    if IS_LOG_ZERO(expected_log_der):
        assert IS_LOG_ZERO(node.log_der)
    else:
        assert_almost_equal(expected_log_der, node.log_der, 15)


def _expected_child_log_ders(parent_ders, sum1_weights, sum2_weights,
                             sum3_weights):
    """Compute by hand the log derivatives expected on the 5 input nodes.

    The formulas encode the wiring built in ``test_sum_layer_backprop``:
    sum1 -> nodes 1, 2, 3; sum2 -> nodes 2, 3, 4; sum3 -> nodes 3, 4, 5.
    Each child derivative is the sum, over its parents, of the parent
    derivative times the connecting weight.
    """
    der1, der2, der3 = parent_ders
    w11, w12, w13 = sum1_weights
    w22, w23, w24 = sum2_weights
    w33, w34, w35 = sum3_weights
    return [
        _log_or_log_zero(der1 * w11),
        _log_or_log_zero(der1 * w12 + der2 * w22),
        _log_or_log_zero(der1 * w13 + der2 * w23 + der3 * w33),
        _log_or_log_zero(der2 * w24 + der3 * w34),
        _log_or_log_zero(der3 * w35),
    ]


def _check_child_log_ders(input_nodes, expected_log_ders):
    """Print and verify the derivatives found on ``input_nodes``."""
    # printing, just in case
    print('child log der', *(node.log_der for node in input_nodes))
    print('exact log der', *expected_log_ders)
    for expected_log_der, node in zip(expected_log_ders, input_nodes):
        _assert_log_der_matches(expected_log_der, node)


def _normalized(first, second, third):
    """Return the three weights divided by their total."""
    # explicit left-to-right addition keeps the float rounding identical
    # to a hand-written ``a + b + c``
    total = first + second + third
    return first / total, second / total, third / total


def test_sum_layer_backprop():
    """Check ``SumLayer.backprop`` and ``SumLayer.update_weights``.

    A layer of three sum nodes is built over five generic input nodes.
    The test then:

    1. evaluates the layer, sets the parent log derivatives and checks the
       derivatives back-propagated to the input nodes against values
       computed by hand;
    2. updates the weights and checks them against a hand-computed,
       normalized update;
    3. resets the input derivatives, feeds new input values and repeats
       the backprop check with the updated weights.
    """
    # input layer made of 5 generic nodes
    node1 = Node()
    node2 = Node()
    node3 = Node()
    node4 = Node()
    node5 = Node()
    input_nodes = (node1, node2, node3, node4, node5)

    # top layer made by 3 sum nodes
    sum1 = SumNode()
    sum2 = SumNode()
    sum3 = SumNode()

    # linking to input nodes
    weight11 = 0.3
    sum1.add_child(node1, weight11)
    weight12 = 0.3
    sum1.add_child(node2, weight12)
    weight13 = 0.4
    sum1.add_child(node3, weight13)

    weight22 = 0.15
    sum2.add_child(node2, weight22)
    weight23 = 0.15
    sum2.add_child(node3, weight23)
    weight24 = 0.7
    sum2.add_child(node4, weight24)

    weight33 = 0.4
    sum3.add_child(node3, weight33)
    weight34 = 0.25
    sum3.add_child(node4, weight34)
    weight35 = 0.35
    sum3.add_child(node5, weight35)

    sum_layer = SumLayer([sum1, sum2, sum3])

    # setting input values
    val1 = 0.0
    node1.set_val(val1)
    val2 = 0.5
    node2.set_val(val2)
    val3 = 0.3
    node3.set_val(val3)
    val4 = 1.0
    node4.set_val(val4)
    val5 = 0.0
    node5.set_val(val5)

    # evaluating
    sum_layer.eval()
    print("eval'd layer:", sum_layer.node_values())

    # set the parent derivatives
    sum_der1 = 1.0
    sum1.log_der = log(sum_der1)
    sum_der2 = 1.0
    sum2.log_der = log(sum_der2)
    sum_der3 = 0.0
    # a zero derivative has no finite log, so it is stored as LOG_ZERO
    sum3.log_der = LOG_ZERO

    # back prop layer wise
    sum_layer.backprop()

    # check for correctness
    expected_log_ders = _expected_child_log_ders(
        (sum_der1, sum_der2, sum_der3),
        (weight11, weight12, weight13),
        (weight22, weight23, weight24),
        (weight33, weight34, weight35),
    )
    _check_child_log_ders(input_nodes, expected_log_ders)

    # updating weights
    # NOTE(review): eta presumably mirrors the step size applied by
    # Spn.test_weight_update -- confirm against its definition
    eta = 0.1
    sum_layer.update_weights(Spn.test_weight_update, 0)

    # checking for correctness
    weight_u11 = sum_der1 * val1 * eta + weight11
    weight_u12 = sum_der1 * val2 * eta + weight12
    weight_u13 = sum_der1 * val3 * eta + weight13

    weight_u22 = sum_der2 * val2 * eta + weight22
    weight_u23 = sum_der2 * val3 * eta + weight23
    weight_u24 = sum_der2 * val4 * eta + weight24

    weight_u33 = sum_der3 * val3 * eta + weight33
    weight_u34 = sum_der3 * val4 * eta + weight34
    weight_u35 = sum_der3 * val5 * eta + weight35

    # normalizing
    weight_u11, weight_u12, weight_u13 = _normalized(
        weight_u11, weight_u12, weight_u13)
    weight_u22, weight_u23, weight_u24 = _normalized(
        weight_u22, weight_u23, weight_u24)
    weight_u33, weight_u34, weight_u35 = _normalized(
        weight_u33, weight_u34, weight_u35)

    expected_weights = (
        weight_u11, weight_u12, weight_u13,
        weight_u22, weight_u23, weight_u24,
        weight_u33, weight_u34, weight_u35,
    )
    found_weights = (
        sum1.weights[0], sum1.weights[1], sum1.weights[2],
        sum2.weights[0], sum2.weights[1], sum2.weights[2],
        sum3.weights[0], sum3.weights[1], sum3.weights[2],
    )
    print('expected weights', *expected_weights)
    print('found weights', *found_weights)

    for expected_weight, found_weight in zip(expected_weights,
                                             found_weights):
        assert_almost_equal(expected_weight, found_weight, 10)

    #
    # resetting derivatives
    #
    for node in input_nodes:
        node.log_der = LOG_ZERO

    # setting new values as inputs
    for node, new_val in zip(input_nodes, (0.0, 0.0, 0.3, 1.0, 1.0)):
        node.set_val(new_val)

    # evaluating again
    sum_layer.eval()
    print("eval'd layer:", sum_layer.node_values())

    # set the parent derivatives
    sum_der1 = 1.0
    sum1.log_der = log(sum_der1)
    sum_der2 = 1.0
    sum2.log_der = log(sum_der2)
    sum_der3 = 0.0
    sum3.log_der = LOG_ZERO

    # back prop layer wise
    sum_layer.backprop()

    # check for correctness, this time against the updated weights
    expected_log_ders = _expected_child_log_ders(
        (sum_der1, sum_der2, sum_der3),
        (weight_u11, weight_u12, weight_u13),
        (weight_u22, weight_u23, weight_u24),
        (weight_u33, weight_u34, weight_u35),
    )
    _check_child_log_ders(input_nodes, expected_log_ders)
def test_sum_layer_create_and_eval():
    """Build a small sum layer and compare its evaluation to manual logs.

    Three generic nodes with fixed values feed three sum nodes. After the
    layer is evaluated, each node value it reports is checked against the
    log of the weighted sum of the child values, computed by hand.
    """
    # generic input nodes, created in order, and the values they hold
    leaf_a, leaf_b, leaf_c = Node(), Node(), Node()
    val_a, val_b, val_c = 1., 1., 0.
    for leaf, value in ((leaf_a, val_a), (leaf_b, val_b), (leaf_c, val_c)):
        leaf.set_val(value)

    # edge weights, named <parent index><child index>
    w11, w12, w13 = 0.2, 0.3, 0.5
    w21, w22 = 0.3, 0.7
    w32, w33 = 0.4, 0.6

    # the sum nodes forming the layer
    top_a, top_b, top_c = SumNode(), SumNode(), SumNode()

    # wire every parent to its children, parent by parent
    wiring = (
        (top_a, ((leaf_a, w11), (leaf_b, w12), (leaf_c, w13))),
        (top_b, ((leaf_a, w21), (leaf_b, w22))),
        (top_c, ((leaf_b, w32), (leaf_c, w33))),
    )
    for parent, edges in wiring:
        for child, weight in edges:
            parent.add_child(child, weight)

    # assemble the layer and evaluate it
    sum_layer = SumLayer([top_a, top_b, top_c])
    sum_layer.eval()

    found_logs = sum_layer.node_values()
    print('Layer eval nodes')
    print(found_logs)

    # the same quantities, worked out by hand
    expected_logs = [
        log(w11 * val_a + w12 * val_b + w13 * val_c),
        log(w21 * val_a + w22 * val_b),
        log(w32 * val_b + w33 * val_c),
    ]
    print('log vals')
    print(expected_logs)

    # checking for correctness
    for expected, found in zip(expected_logs, found_logs):
        assert_almost_equal(expected, found, PRECISION)