Пример #1
0
def test_product_layer_create_and_eval():
    """Check that evaluating a ProductLayer yields the log of each product.

    Three generic nodes holding 0.8, 1.0 and 0.0 are wired pairwise into
    three product nodes.  After ``eval()``, the values returned by
    ``node_values()`` are compared against logs computed by hand; a
    product equal to zero is expected to come back as a log-zero value.
    """
    # creating generic nodes
    node1 = Node()
    node2 = Node()
    node3 = Node()

    # whose values are
    val1 = 0.8
    val2 = 1.
    val3 = 0.
    for node, val in ((node1, val1), (node2, val2), (node3, val3)):
        node.set_val(val)

    # creating product nodes
    prod1 = ProductNode()
    prod2 = ProductNode()
    prod3 = ProductNode()

    # adding children: every product node gets a distinct pair of leaves
    prod1.add_child(node1)
    prod1.add_child(node2)

    prod2.add_child(node1)
    prod2.add_child(node3)

    prod3.add_child(node2)
    prod3.add_child(node3)

    # adding product nodes to layer
    product_layer = ProductLayer([prod1, prod2, prod3])

    # evaluating
    product_layer.eval()

    # getting log vals
    layer_evals = product_layer.node_values()
    print('layer eval nodes')
    print(layer_evals)

    # computing our values, in the same order as the product nodes above;
    # log() is undefined at zero, so a zero product maps to LOG_ZERO
    prodvals = [val1 * val2, val1 * val3, val2 * val3]
    logvals = [log(prodval) if prodval > 0. else LOG_ZERO
               for prodval in prodvals]
    print('log vals')
    print(logvals)

    # 'layer_eval' rather than 'eval' so the builtin is not shadowed
    for logval, layer_eval in zip(logvals, layer_evals):
        if logval == LOG_ZERO:
            # for zero log check this way for correctness; rely on
            # truthiness, since an identity check against True fails for
            # truthy non-bool results (e.g. numpy.bool_)
            assert IS_LOG_ZERO(layer_eval)
        else:
            assert_almost_equal(logval, layer_eval, PRECISION)
Пример #2
0
def _expected_child_log_ders(vals, prod_ders):
    """Return the hand-computed log derivatives of the five input nodes.

    ``vals`` holds the five input values and ``prod_ders`` the derivatives
    of the three product nodes, whose children are inputs (1, 2, 3),
    (2, 3, 4) and (3, 4, 5) respectively.  The derivative of an input is
    the sum, over the product nodes it belongs to, of the parent derivative
    times the values of its sibling inputs.  A non-positive sum is mapped
    to LOG_ZERO explicitly instead of waiting for ``log`` to raise.
    """
    val1, val2, val3, val4, val5 = vals
    prod_der1, prod_der2, prod_der3 = prod_ders
    # operand order is kept fixed so the floating point sums are reproducible
    ders = [
        prod_der1 * val2 * val3,
        prod_der1 * val1 * val3 +
        prod_der2 * val3 * val4,
        prod_der2 * val2 * val4 +
        prod_der3 * val4 * val5 +
        prod_der1 * val1 * val2,
        prod_der2 * val2 * val3 +
        prod_der3 * val3 * val5,
        prod_der3 * val3 * val4,
    ]
    return [log(der) if der > 0. else LOG_ZERO for der in ders]


def _assert_child_log_ders(nodes, expected_log_ders, decimal=15):
    """Assert that each node's ``log_der`` matches its expected value."""
    for node, log_der in zip(nodes, expected_log_ders):
        if IS_LOG_ZERO(log_der):
            # log-zero values are checked via the predicate, not numerically
            assert IS_LOG_ZERO(node.log_der)
        else:
            assert_almost_equal(log_der, node.log_der, decimal)


def _backprop_and_check(prod_layer, prod_nodes, input_nodes, vals, prod_ders):
    """Evaluate and backprop the layer once, then verify the input log ders."""
    # evaluating
    prod_layer.eval()
    print('eval\'d layer:', prod_layer.node_values())

    # set the parent derivatives (a zero derivative is stored as LOG_ZERO)
    for prod, der in zip(prod_nodes, prod_ders):
        prod.log_der = log(der) if der > 0. else LOG_ZERO

    # back prop layer wise
    prod_layer.backprop()

    # check for correctness
    expected_log_ders = _expected_child_log_ders(vals, prod_ders)

    # printing, just in case
    print('child log der', *[node.log_der for node in input_nodes])
    print('exact log der', *expected_log_ders)

    _assert_child_log_ders(input_nodes, expected_log_ders)


def test_prod_layer_backprop():
    """Check ProductLayer.backprop() against hand-computed log derivatives.

    Five generic input nodes feed three product nodes with overlapping
    children.  For two different input assignments the layer is evaluated,
    the product nodes' log derivatives are set to log(1), log(1) and
    LOG_ZERO, and after ``backprop()`` every input node's ``log_der`` is
    compared with the value computed directly from the input values.
    """
    # input layer made of 5 generic nodes
    node1 = Node()
    node2 = Node()
    node3 = Node()
    node4 = Node()
    node5 = Node()
    input_nodes = [node1, node2, node3, node4, node5]

    input_layer = CategoricalInputLayer([node1, node2,
                                         node3, node4,
                                         node5])

    # top layer made by 3 prod nodes
    prod1 = ProductNode()
    prod2 = ProductNode()
    prod3 = ProductNode()
    prod_nodes = [prod1, prod2, prod3]

    # linking to input nodes
    prod1.add_child(node1)
    prod1.add_child(node2)
    prod1.add_child(node3)

    prod2.add_child(node2)
    prod2.add_child(node3)
    prod2.add_child(node4)

    prod3.add_child(node3)
    prod3.add_child(node4)
    prod3.add_child(node5)

    prod_layer = ProductLayer([prod1, prod2, prod3])

    # derivatives assigned to the product nodes in both passes
    prod_ders = (1.0, 1.0, 0.0)

    # setting input values
    vals = (0.0, 0.5, 0.3, 1.0, 0.0)
    for node, val in zip(input_nodes, vals):
        node.set_val(val)

    print('input', [node.log_val for node in input_layer.nodes()])
    _backprop_and_check(prod_layer, prod_nodes, input_nodes, vals, prod_ders)

    # resetting derivatives
    for node in input_nodes:
        node.log_der = LOG_ZERO

    # setting new values as inputs
    vals = (0.0, 0.0, 0.3, 1.0, 1.0)
    for node, val in zip(input_nodes, vals):
        node.set_val(val)

    # evaluating again
    _backprop_and_check(prod_layer, prod_nodes, input_nodes, vals, prod_ders)