Ejemplo n.º 1
0
def _log_or_log_zero(value):
    """Return ``log(value)``, or ``LOG_ZERO`` when the logarithm is undefined.

    ``log`` is whatever the enclosing module imported. Only the errors a
    logarithm is expected to raise for a non-positive argument are caught:
    ``ValueError`` (``math.log``) and ``FloatingPointError`` (NumPy, when it
    is configured to raise on floating-point errors). Unrelated failures and
    ``KeyboardInterrupt`` therefore propagate instead of being silently
    turned into ``LOG_ZERO``.
    """
    try:
        return log(value)
    except (ValueError, FloatingPointError):
        return LOG_ZERO


def _expected_child_log_ders(sum_ders, weights):
    """Return the expected log derivative of each of the 5 input nodes.

    ``sum_ders`` holds the derivatives of the three sum nodes and ``weights``
    one row of three weights per sum node. The connectivity is the one built
    by ``test_sum_layer_backprop``: sum1 -> nodes 1,2,3; sum2 -> nodes 2,3,4;
    sum3 -> nodes 3,4,5. Each child derivative is the weighted sum of the
    derivatives of its parents.
    """
    der1, der2, der3 = sum_ders
    (w11, w12, w13), (w22, w23, w24), (w33, w34, w35) = weights
    return [
        _log_or_log_zero(der1 * w11),
        _log_or_log_zero(der1 * w12 +
                         der2 * w22),
        _log_or_log_zero(der1 * w13 +
                         der2 * w23 +
                         der3 * w33),
        _log_or_log_zero(der2 * w24 +
                         der3 * w34),
        _log_or_log_zero(der3 * w35),
    ]


def _expected_updated_weights(sum_der, child_vals, weights, eta):
    """Return the three weights of one sum node after a normalized update.

    Each weight is moved by ``sum_der * child_val * eta`` and the row is then
    divided by its total so that it sums to one again.
    """
    updated = [sum_der * val * eta + weight
               for val, weight in zip(child_vals, weights)]
    # explicit left-to-right addition: keeps the exact floating point result
    total = updated[0] + updated[1] + updated[2]
    return [weight / total for weight in updated]


def _assert_log_ders_match(expected_log_ders, nodes):
    """Assert that every node's ``log_der`` agrees with its expected value.

    Expected values flagged by ``IS_LOG_ZERO`` are compared through
    ``IS_LOG_ZERO`` as well; all others numerically, to 15 decimals.
    """
    for expected, node in zip(expected_log_ders, nodes):
        if IS_LOG_ZERO(expected):
            assert IS_LOG_ZERO(node.log_der)
        else:
            assert_almost_equal(expected, node.log_der, 15)


def test_sum_layer_backprop():
    """Check ``SumLayer.backprop`` and ``SumLayer.update_weights``.

    A layer of three sum nodes over five input nodes is evaluated and
    back-propagated; the children's log derivatives are compared against
    values computed by hand. The weights are then updated, compared against
    the hand-computed normalized update, and the backprop check is repeated
    with new inputs and the updated weights.
    """
    # input layer made of 5 generic nodes
    nodes = [Node() for _ in range(5)]
    node1, node2, node3, node4, node5 = nodes

    # top layer made by 3 sum nodes
    sum1 = SumNode()
    sum2 = SumNode()
    sum3 = SumNode()
    sum_nodes = (sum1, sum2, sum3)

    # linking to input nodes: one row of children/weights per sum node
    children = ((node1, node2, node3),
                (node2, node3, node4),
                (node3, node4, node5))
    weights = ((0.3, 0.3, 0.4),
               (0.15, 0.15, 0.7),
               (0.4, 0.25, 0.35))
    for sum_node, row_children, row_weights in zip(sum_nodes,
                                                   children,
                                                   weights):
        for child, weight in zip(row_children, row_weights):
            sum_node.add_child(child, weight)

    sum_layer = SumLayer([sum1, sum2, sum3])

    # derivatives assigned to the three parents in both passes
    sum_ders = (1.0, 1.0, 0.0)

    def backprop_and_check(vals, current_weights):
        """Feed ``vals``, evaluate, back-propagate and verify the children."""
        # setting input values
        for node, val in zip(nodes, vals):
            node.set_val(val)

        # evaluating
        sum_layer.eval()
        print('eval\'d layer:', sum_layer.node_values())

        # set the parent derivatives; the third one is zero, so the
        # LOG_ZERO sentinel is stored instead of taking log(0)
        sum1.log_der = log(sum_ders[0])
        sum2.log_der = log(sum_ders[1])
        sum3.log_der = LOG_ZERO

        # back prop layer wise
        sum_layer.backprop()

        # check for correctness
        expected_log_ders = _expected_child_log_ders(sum_ders,
                                                     current_weights)

        # printing, just in case
        print('child log der', *[node.log_der for node in nodes])
        print('exact log der', *expected_log_ders)

        _assert_log_ders_match(expected_log_ders, nodes)

    vals = (0.0, 0.5, 0.3, 1.0, 0.0)
    backprop_and_check(vals, weights)

    # updating weights
    # NOTE(review): eta is presumably the learning rate applied by
    # Spn.test_weight_update - confirm against its definition
    eta = 0.1
    sum_layer.update_weights(Spn.test_weight_update, 0)

    # checking for correctness: the values seen by each sum node's children
    child_vals = ((vals[0], vals[1], vals[2]),
                  (vals[1], vals[2], vals[3]),
                  (vals[2], vals[3], vals[4]))
    updated_weights = [
        _expected_updated_weights(sum_der, row_vals, row_weights, eta)
        for sum_der, row_vals, row_weights in zip(sum_ders,
                                                  child_vals,
                                                  weights)
    ]

    print('expected weights',
          *[weight for row in updated_weights for weight in row])
    print('found weights',
          *[sum_node.weights[i]
            for sum_node in sum_nodes for i in range(3)])
    for sum_node, row in zip(sum_nodes, updated_weights):
        for i, expected_weight in enumerate(row):
            assert_almost_equal(expected_weight, sum_node.weights[i], 10)

    #
    # resetting derivatives
    #
    for node in nodes:
        node.log_der = LOG_ZERO

    # setting new values as inputs and checking again, this time against
    # the updated weights
    backprop_and_check((0.0, 0.0, 0.3, 1.0, 1.0), updated_weights)
Ejemplo n.º 2
0
def test_sum_node_backprop():
    """Check ``SumNode.backprop`` for two sum nodes sharing two children.

    Both parents back-propagate into the same pair of children; each child's
    ``log_der`` is compared with the log of the weighted sum of the parent
    derivatives, first with both derivatives at 1.0 and then, after a reset,
    with 0.5 and 0.0.
    """
    # the two shared children, both set to 1
    child1 = Node()
    child1.set_val(1.)

    child2 = Node()
    child2.set_val(1.)

    # first parent and its links
    weight11, weight12 = 0.8, 0.2
    sum_node1 = SumNode()
    sum_node1.add_child(child1, weight11)
    sum_node1.add_child(child2, weight12)

    # second parent (coparent) over the same children
    weight21, weight22 = 0.6, 0.4
    sum_node2 = SumNode()
    sum_node2.add_child(child1, weight21)
    sum_node2.add_child(child2, weight22)

    # bottom-up evaluation
    sum_node1.eval()
    sum_node2.eval()

    def verify_children(der1, der2):
        """Compare both children's log_der with the hand-computed values."""
        expected1 = log(weight11 * der1 +
                        weight21 * der2)
        expected2 = log(weight12 * der1 +
                        weight22 * der2)
        print('log ders 1:{lgd1} 2:{lgd2}'.format(lgd1=expected1,
                                                  lgd2=expected2))
        assert_almost_equal(expected1, child1.log_der, 15)
        assert_almost_equal(expected2, child2.log_der, 15)

    # first round: both parents carry a derivative of 1
    der_parent1 = 1.0
    sum_node1.log_der = log(der_parent1)
    sum_node1.backprop()

    der_parent2 = 1.0
    sum_node2.log_der = log(der_parent2)
    sum_node2.backprop()

    verify_children(der_parent1, der_parent2)

    # clear what the children accumulated
    child1.log_der = LOG_ZERO
    child2.log_der = LOG_ZERO

    # second round: the second parent has a zero derivative, stored as
    # the LOG_ZERO sentinel
    der_parent1 = 0.5
    sum_node1.log_der = log(der_parent1)
    sum_node1.backprop()

    der_parent2 = 0.0
    sum_node2.log_der = LOG_ZERO
    sum_node2.backprop()

    verify_children(der_parent1, der_parent2)