# Example #1
# 0
def assert_log_array_almost_equal(array_1, array_2):
    """
    Assert that two collections of log-values are element-wise almost equal.

    Each argument is flattened first: through its own ``flatten()`` method
    when it has one (e.g. a numpy array), otherwise through ``flatten_list``;
    if neither applies it is compared as given. Where ``array_1`` holds
    exactly ``LOG_ZERO``, the matching ``array_2`` element only has to
    satisfy ``IS_LOG_ZERO`` (log(0) has several encodings, e.g. -2000 and
    -1000); all other elements are compared with ``assert_almost_equal`` at
    ``PRECISION``.

    Raises AssertionError if the flattened inputs differ in length or in any
    element.
    """

    def _flatten(array):
        # numpy-style arrays flatten themselves; nested lists go through
        # flatten_list; anything else is compared as-is (best effort)
        try:
            return array.flatten()
        except (AttributeError, TypeError):
            pass
        try:
            return flatten_list(array)
        except Exception:
            # flatten_list is defined elsewhere and its failure modes are not
            # known here; keep the best-effort fallback, but no longer
            # swallow KeyboardInterrupt/SystemExit like the bare except did
            return array

    array_1 = list(_flatten(array_1))
    array_2 = list(_flatten(array_2))

    # zip() alone would silently ignore trailing elements of the longer input
    assert len(array_1) == len(array_2), (
        'length mismatch: {0} != {1}'.format(len(array_1), len(array_2)))

    for elem_1, elem_2 in zip(array_1, array_2):
        if elem_1 == LOG_ZERO:
            # -2000 == -1000 since exp(-2000) == exp(-1000)
            assert IS_LOG_ZERO(elem_2)
        else:
            assert_almost_equal(elem_1, elem_2, PRECISION)
# Example #2
# 0
def test_product_layer_create_and_eval():
    """Evaluate a ProductLayer of three product nodes over three inputs.

    Each product node multiplies a different pair of the input values
    (0.8, 1.0, 0.0). The layer's node values (log values) must equal the log
    of those products, with zero products expected as LOG_ZERO.
    """
    # creating generic nodes
    node1 = Node()
    node2 = Node()
    node3 = Node()

    # whose values are
    val1 = 0.8
    val2 = 1.
    val3 = 0.
    node1.set_val(val1)
    node2.set_val(val2)
    node3.set_val(val3)

    # creating product nodes
    prod1 = ProductNode()
    prod2 = ProductNode()
    prod3 = ProductNode()

    # adding children
    prod1.add_child(node1)
    prod1.add_child(node2)

    prod2.add_child(node1)
    prod2.add_child(node3)

    prod3.add_child(node2)
    prod3.add_child(node3)

    # adding product nodes to layer
    product_layer = ProductLayer([prod1, prod2, prod3])

    # evaluating
    product_layer.eval()

    # getting log vals
    layer_evals = product_layer.node_values()
    print('layer eval nodes')
    print(layer_evals)

    # computing our values
    prodval1 = val1 * val2
    logval1 = log(prodval1) if prodval1 > 0. else LOG_ZERO
    prodval2 = val1 * val3
    logval2 = log(prodval2) if prodval2 > 0. else LOG_ZERO
    prodval3 = val2 * val3
    logval3 = log(prodval3) if prodval3 > 0. else LOG_ZERO
    logvals = [logval1, logval2, logval3]
    print('log vals')
    print(logvals)

    # 'layer_eval' rather than 'eval' to avoid shadowing the builtin
    for logval, layer_eval in zip(logvals, layer_evals):
        if logval == LOG_ZERO:
            # for zero log check this way for correctness
            assert IS_LOG_ZERO(layer_eval)
        else:
            assert_almost_equal(logval, layer_eval, PRECISION)
# Example #3
# 0
def test_categorical_input_layer():
    """Evaluate a CategoricalInputLayer for every pair of variables.

    For each pair of variable ids and each value of the first variable, the
    layer holds a CategoricalIndicatorNode (first variable, that value) and
    a CategoricalSmoothedNode (second variable, built with ``alpha`` and that
    variable's ``freqs`` entry). After evaluating on ``obs``, the indicator
    must give log(1) when the observed value matches (or equals ``MARG_IND``)
    and LOG_ZERO otherwise. The smoothed node must match
    ``compute_smoothed_ll``.

    NOTE(review): ``vars``, ``freqs`` and ``obs`` are module-level fixtures
    defined elsewhere. ``vars`` appears to hold each variable's number of
    values, since it bounds the value loop, and it shadows the builtin
    ``vars``.
    """
    print('categorical input layer')
    # I could loop through alpha as well
    alpha = 0.1

    for var_id1, n_vals1 in enumerate(vars):
        for var_id2, var_vals2 in enumerate(vars):
            for var_val1 in range(n_vals1):
                print('varid1, varid2, varval1',
                      var_id1, var_id2, var_val1)
                node1 = CategoricalIndicatorNode(var_id1,
                                                 var_val1)
                node2 = CategoricalSmoothedNode(
                    var_id2, var_vals2, alpha, freqs[var_id2])

                # creating the generic input layer
                input_layer = CategoricalInputLayer([node1,
                                                     node2])

                # evaluating according to an observation
                input_layer.eval(obs)

                layer_evals = input_layer.node_values()
                print('layer eval nodes')
                print(layer_evals)

                # computing evaluation by hand: the indicator is 1 on a
                # matching or marginalised observation, 0 otherwise
                val1 = 1 if var_val1 == obs[var_id1] or obs[
                    var_id1] == MARG_IND else 0
                logval1 = log(val1) if val1 == 1 else LOG_ZERO

                logval2 = compute_smoothed_ll(
                    obs[var_id2], freqs[var_id2], var_vals2, alpha)
                logvals = [logval1, logval2]
                print('log vals')
                print(logvals)

                # 'layer_eval' rather than 'eval' to avoid shadowing the
                # builtin
                for logval, layer_eval in zip(logvals, layer_evals):
                    if logval == LOG_ZERO:
                        # for zero log check this way for correctness
                        assert IS_LOG_ZERO(layer_eval)
                    else:
                        assert_almost_equal(logval, layer_eval, PRECISION)
# Example #4
# 0
def test_product_node_create_and_eval():
    """Evaluate a two-child ProductNode for several child values.

    Checks ``log_val`` for these child values:

    * (1, 1) gives log(1).
    * (0, 1) gives exactly LOG_ZERO.
    * (0, 0) gives a log-zero value. This case is checked with
      ``IS_LOG_ZERO``, because the sum of two LOG_ZERO terms is not LOG_ZERO
      itself.
    """
    # create child nodes
    child1 = Node()
    val1 = 1.
    child1.set_val(val1)

    child2 = Node()
    val2 = 1.
    child2.set_val(val2)

    # create product node and add children
    prod_node = ProductNode()
    prod_node.add_child(child1)
    prod_node.add_child(child2)
    assert len(prod_node.children) == 2

    print(prod_node)

    # evaluation
    prod_node.eval()
    print(prod_node.log_val)
    assert_almost_equal(prod_node.log_val,
                        log(val1 * val2),
                        places=15)

    # changing values 0,1 -> LOG_ZERO
    val1 = 0.
    child1.set_val(val1)
    val2 = 1.
    child2.set_val(val2)

    prod_node.eval()
    print(prod_node.log_val)
    assert_almost_equal(prod_node.log_val,
                        LOG_ZERO,
                        places=15)

    # changing values 0,0 -> LOG_ZERO + LOG_ZERO
    val1 = 0.
    child1.set_val(val1)
    val2 = 0.
    child2.set_val(val2)

    prod_node.eval()
    print(prod_node.log_val)
    # now testing with macro since -1000 + -1000 != -1000
    assert IS_LOG_ZERO(prod_node.log_val)
# Example #5
# 0
def test_sum_layer_backprop():
    """Check SumLayer backprop and weight update over two rounds of inputs.

    Three sum nodes are connected to five generic input nodes. The test
    sets the parent derivatives (1, 1, 0) on the sum nodes and checks:

    * Each child's log-derivative equals log(sum over parents of
      parent_der * weight), or LOG_ZERO when that sum is 0.
    * After ``Spn.test_weight_update``, each weight equals
      weight + eta * parent_der * child_val, renormalised per sum node.

    The backprop check is then repeated with new inputs and the updated
    weights.
    """

    def safe_log(value):
        # expected values may be exactly 0; map them to LOG_ZERO explicitly
        # instead of relying on log(0) raising inside a bare except
        return log(value) if value > 0. else LOG_ZERO

    def check_log(expected, found, *places):
        # log(0) has several encodings (e.g. LOG_ZERO + LOG_ZERO), so zero
        # values are compared with IS_LOG_ZERO instead of numerically
        if IS_LOG_ZERO(expected):
            assert IS_LOG_ZERO(found)
        else:
            assert_almost_equal(expected, found, *places)

    # input layer made of 5 generic nodes
    node1 = Node()
    node2 = Node()
    node3 = Node()
    node4 = Node()
    node5 = Node()
    nodes = [node1, node2, node3, node4, node5]

    # top layer made by 3 sum nodes
    sum1 = SumNode()
    sum2 = SumNode()
    sum3 = SumNode()

    # linking to input nodes
    weight11 = 0.3
    sum1.add_child(node1, weight11)
    weight12 = 0.3
    sum1.add_child(node2, weight12)
    weight13 = 0.4
    sum1.add_child(node3, weight13)

    weight22 = 0.15
    sum2.add_child(node2, weight22)
    weight23 = 0.15
    sum2.add_child(node3, weight23)
    weight24 = 0.7
    sum2.add_child(node4, weight24)

    weight33 = 0.4
    sum3.add_child(node3, weight33)
    weight34 = 0.25
    sum3.add_child(node4, weight34)
    weight35 = 0.35
    sum3.add_child(node5, weight35)

    sum_layer = SumLayer([sum1, sum2, sum3])

    # setting input values
    val1 = 0.0
    node1.set_val(val1)
    val2 = 0.5
    node2.set_val(val2)
    val3 = 0.3
    node3.set_val(val3)
    val4 = 1.0
    node4.set_val(val4)
    val5 = 0.0
    node5.set_val(val5)

    # evaluating
    sum_layer.eval()
    print('eval\'d layer:', sum_layer.node_values())

    # set the parent derivatives
    sum_der1 = 1.0
    sum1.log_der = log(sum_der1)

    sum_der2 = 1.0
    sum2.log_der = log(sum_der2)

    sum_der3 = 0.0
    sum3.log_der = LOG_ZERO

    # back prop layer wise
    sum_layer.backprop()

    # check for correctness: d(sum)/d(child) is the child's weight, so each
    # child's derivative is the weighted sum of its parents' derivatives
    log_der1 = safe_log(sum_der1 * weight11)
    log_der2 = safe_log(sum_der1 * weight12 +
                        sum_der2 * weight22)
    log_der3 = safe_log(sum_der1 * weight13 +
                        sum_der2 * weight23 +
                        sum_der3 * weight33)
    log_der4 = safe_log(sum_der2 * weight24 +
                        sum_der3 * weight34)
    log_der5 = safe_log(sum_der3 * weight35)

    # printing, just in case
    print('child log der', node1.log_der, node2.log_der,
          node3.log_der, node4.log_der, node5.log_der)
    print('exact log der', log_der1, log_der2, log_der3,
          log_der4, log_der5)

    for log_der, node in zip((log_der1, log_der2, log_der3,
                              log_der4, log_der5), nodes):
        check_log(log_der, node.log_der, 15)

    # updating weights
    # NOTE(review): eta is not passed to update_weights; it presumably
    # matches the rate used inside Spn.test_weight_update - TODO confirm
    eta = 0.1
    sum_layer.update_weights(Spn.test_weight_update, 0)
    # checking for correctness: w <- w + eta * parent_der * child_val
    weight_u11 = sum_der1 * val1 * eta + weight11
    weight_u12 = sum_der1 * val2 * eta + weight12
    weight_u13 = sum_der1 * val3 * eta + weight13

    weight_u22 = sum_der2 * val2 * eta + weight22
    weight_u23 = sum_der2 * val3 * eta + weight23
    weight_u24 = sum_der2 * val4 * eta + weight24

    weight_u33 = sum_der3 * val3 * eta + weight33
    weight_u34 = sum_der3 * val4 * eta + weight34
    weight_u35 = sum_der3 * val5 * eta + weight35

    # normalizing
    weight_sum1 = weight_u11 + weight_u12 + weight_u13
    weight_sum2 = weight_u22 + weight_u23 + weight_u24
    weight_sum3 = weight_u33 + weight_u34 + weight_u35

    weight_u11 = weight_u11 / weight_sum1
    weight_u12 = weight_u12 / weight_sum1
    weight_u13 = weight_u13 / weight_sum1

    weight_u22 = weight_u22 / weight_sum2
    weight_u23 = weight_u23 / weight_sum2
    weight_u24 = weight_u24 / weight_sum2

    weight_u33 = weight_u33 / weight_sum3
    weight_u34 = weight_u34 / weight_sum3
    weight_u35 = weight_u35 / weight_sum3

    print('expected weights', weight_u11, weight_u12, weight_u13,
          weight_u22, weight_u23, weight_u24,
          weight_u33, weight_u34, weight_u35)
    print('found weights', sum1.weights[0], sum1.weights[1], sum1.weights[2],
          sum2.weights[0], sum2.weights[1], sum2.weights[2],
          sum3.weights[0], sum3.weights[1], sum3.weights[2])

    expected_weights = [weight_u11, weight_u12, weight_u13,
                        weight_u22, weight_u23, weight_u24,
                        weight_u33, weight_u34, weight_u35]
    found_weights = [sum_node.weights[pos]
                     for sum_node in (sum1, sum2, sum3)
                     for pos in range(3)]
    for expected, found in zip(expected_weights, found_weights):
        assert_almost_equal(expected, found, 10)

    #
    # resetting derivatives
    #
    for node in nodes:
        node.log_der = LOG_ZERO

    # setting new values as inputs
    val1 = 0.0
    node1.set_val(val1)
    val2 = 0.0
    node2.set_val(val2)
    val3 = 0.3
    node3.set_val(val3)
    val4 = 1.0
    node4.set_val(val4)
    val5 = 1.0
    node5.set_val(val5)

    # evaluating again
    sum_layer.eval()
    print('eval\'d layer:', sum_layer.node_values())

    # set the parent derivatives
    sum_der1 = 1.0
    sum1.log_der = log(sum_der1)

    sum_der2 = 1.0
    sum2.log_der = log(sum_der2)

    sum_der3 = 0.0
    sum3.log_der = LOG_ZERO

    # back prop layer wise
    sum_layer.backprop()

    # check for correctness, now against the updated weights
    log_der1 = safe_log(sum_der1 * weight_u11)
    log_der2 = safe_log(sum_der1 * weight_u12 +
                        sum_der2 * weight_u22)
    log_der3 = safe_log(sum_der1 * weight_u13 +
                        sum_der2 * weight_u23 +
                        sum_der3 * weight_u33)
    log_der4 = safe_log(sum_der2 * weight_u24 +
                        sum_der3 * weight_u34)
    log_der5 = safe_log(sum_der3 * weight_u35)

    # printing, just in case
    print('child log der', node1.log_der, node2.log_der,
          node3.log_der, node4.log_der, node5.log_der)
    print('exact log der', log_der1, log_der2, log_der3,
          log_der4, log_der5)

    for log_der, node in zip((log_der1, log_der2, log_der3,
                              log_der4, log_der5), nodes):
        check_log(log_der, node.log_der, 15)
# Example #6
# 0
def test_prod_layer_backprop():
    """Check ProductLayer.backprop over two rounds of input values.

    Three product nodes each multiply three of five generic input nodes
    (1-2-3, 2-3-4, 3-4-5). With parent derivatives (1, 1, 0), each child's
    log-derivative must equal log(sum over parents of parent_der * product
    of that parent's other children), or LOG_ZERO when the sum is 0. Several
    inputs are 0, so LOG_ZERO expectations are exercised as well.
    """

    def safe_log(value):
        # expected values may be exactly 0; map them to LOG_ZERO explicitly
        # instead of relying on log(0) raising inside a bare except
        return log(value) if value > 0. else LOG_ZERO

    def check_log(expected, found, *places):
        # log(0) has several encodings (e.g. LOG_ZERO + LOG_ZERO), so zero
        # values are compared with IS_LOG_ZERO instead of numerically
        if IS_LOG_ZERO(expected):
            assert IS_LOG_ZERO(found)
        else:
            assert_almost_equal(expected, found, *places)

    # input layer made of 5 generic nodes
    node1 = Node()
    node2 = Node()
    node3 = Node()
    node4 = Node()
    node5 = Node()
    nodes = [node1, node2, node3, node4, node5]

    input_layer = CategoricalInputLayer([node1, node2,
                                         node3, node4,
                                         node5])

    # top layer made by 3 prod nodes
    prod1 = ProductNode()
    prod2 = ProductNode()
    prod3 = ProductNode()

    # linking to input nodes
    prod1.add_child(node1)
    prod1.add_child(node2)
    prod1.add_child(node3)

    prod2.add_child(node2)
    prod2.add_child(node3)
    prod2.add_child(node4)

    prod3.add_child(node3)
    prod3.add_child(node4)
    prod3.add_child(node5)

    prod_layer = ProductLayer([prod1, prod2, prod3])

    # setting input values
    val1 = 0.0
    node1.set_val(val1)
    val2 = 0.5
    node2.set_val(val2)
    val3 = 0.3
    node3.set_val(val3)
    val4 = 1.0
    node4.set_val(val4)
    val5 = 0.0
    node5.set_val(val5)

    print('input', [node.log_val for node in input_layer.nodes()])
    # evaluating
    prod_layer.eval()
    print('eval\'d layer:', prod_layer.node_values())

    # set the parent derivatives
    prod_der1 = 1.0
    prod1.log_der = log(prod_der1)

    prod_der2 = 1.0
    prod2.log_der = log(prod_der2)

    prod_der3 = 0.0
    prod3.log_der = LOG_ZERO

    # back prop layer wise
    prod_layer.backprop()

    # check for correctness: d(prod)/d(child) is the product of the parent's
    # other children
    log_der1 = safe_log(prod_der1 * val2 * val3)
    log_der2 = safe_log(prod_der1 * val1 * val3 +
                        prod_der2 * val3 * val4)
    log_der3 = safe_log(prod_der2 * val2 * val4 +
                        prod_der3 * val4 * val5 +
                        prod_der1 * val1 * val2)
    log_der4 = safe_log(prod_der2 * val2 * val3 +
                        prod_der3 * val3 * val5)
    log_der5 = safe_log(prod_der3 * val3 * val4)

    # printing, just in case
    print('child log der', node1.log_der, node2.log_der,
          node3.log_der, node4.log_der, node5.log_der)
    print('exact log der', log_der1, log_der2, log_der3,
          log_der4, log_der5)

    for log_der, node in zip((log_der1, log_der2, log_der3,
                              log_der4, log_der5), nodes):
        check_log(log_der, node.log_der, 15)

    # resetting derivatives
    for node in nodes:
        node.log_der = LOG_ZERO

    # setting new values as inputs
    val1 = 0.0
    node1.set_val(val1)
    val2 = 0.0
    node2.set_val(val2)
    val3 = 0.3
    node3.set_val(val3)
    val4 = 1.0
    node4.set_val(val4)
    val5 = 1.0
    node5.set_val(val5)

    # evaluating again
    prod_layer.eval()
    print('eval\'d layer:', prod_layer.node_values())

    # set the parent derivatives
    prod_der1 = 1.0
    prod1.log_der = log(prod_der1)

    prod_der2 = 1.0
    prod2.log_der = log(prod_der2)

    prod_der3 = 0.0
    prod3.log_der = LOG_ZERO

    # back prop layer wise
    prod_layer.backprop()

    # check for correctness
    log_der1 = safe_log(prod_der1 * val2 * val3)
    log_der2 = safe_log(prod_der1 * val1 * val3 +
                        prod_der2 * val3 * val4)
    log_der3 = safe_log(prod_der2 * val2 * val4 +
                        prod_der3 * val4 * val5 +
                        prod_der1 * val1 * val2)
    log_der4 = safe_log(prod_der2 * val2 * val3 +
                        prod_der3 * val3 * val5)
    log_der5 = safe_log(prod_der3 * val3 * val4)

    # printing, just in case
    print('child log der', node1.log_der, node2.log_der,
          node3.log_der, node4.log_der, node5.log_der)
    print('exact log der', log_der1, log_der2, log_der3,
          log_der4, log_der5)

    for log_der, node in zip((log_der1, log_der2, log_der3,
                              log_der4, log_der5), nodes):
        check_log(log_der, node.log_der, 15)
# Example #7
# 0
def test_spn_mpe_eval_and_traversal():
    """Evaluate a small three-layer SPN with MPE inference and verify it.

    The network has five generic inputs, a layer of three sum nodes, a layer
    of two product nodes and a root layer of two sum nodes. Under
    ``spn.test_mpe_eval()``, every sum node is expected to hold the log of
    the max of its weighted children rather than their sum. Each layer's
    ``log_val`` is compared with hand-computed values, and the MPE top-down
    traversal is printed.
    """

    def log_or_zero(value):
        # values numerically indistinguishable from 0 map to LOG_ZERO
        return LOG_ZERO if numpy.isclose(value, 0) else log(value)

    def check_log_vals(expected_vals, found_nodes):
        # log(0) may be encoded in more than one way, so zeros are only
        # checked with IS_LOG_ZERO; the rest numerically
        for expected, found in zip(expected_vals, found_nodes):
            if IS_LOG_ZERO(expected):
                assert IS_LOG_ZERO(found.log_val)
            else:
                assert_almost_equal(expected, found.log_val)

    # create initial layer
    node1, node2, node3, node4, node5 = [Node() for _ in range(5)]
    input_layer = CategoricalInputLayer([node1, node2,
                                         node3, node4,
                                         node5])

    # top layer made by 3 sum nodes, linked to the inputs
    sum1, sum2, sum3 = [SumNode() for _ in range(3)]
    weight11, weight12, weight13 = 0.3, 0.3, 0.4
    weight22, weight23, weight24 = 0.15, 0.15, 0.7
    weight33, weight34, weight35 = 0.4, 0.25, 0.35
    sum_links = [(sum1, node1, weight11), (sum1, node2, weight12),
                 (sum1, node3, weight13),
                 (sum2, node2, weight22), (sum2, node3, weight23),
                 (sum2, node4, weight24),
                 (sum3, node3, weight33), (sum3, node4, weight34),
                 (sum3, node5, weight35)]
    for parent, child, weight in sum_links:
        parent.add_child(child, weight)
    sum_layer = SumLayer([sum1, sum2, sum3])

    # middle layer: two product nodes sharing sum2
    prod1, prod2 = ProductNode(), ProductNode()
    for parent, child in ((prod1, sum1), (prod1, sum2),
                          (prod2, sum2), (prod2, sum3)):
        parent.add_child(child)
    prod_layer = ProductLayer([prod1, prod2])

    # root layer, double sum over both products
    root1, root2 = SumNode(), SumNode()
    weightr11, weightr12 = 0.5, 0.5
    weightr21, weightr22 = 0.9, 0.1
    root_links = [(root1, prod1, weightr11), (root1, prod2, weightr12),
                  (root2, prod1, weightr21), (root2, prod2, weightr22)]
    for parent, child, weight in root_links:
        parent.add_child(child, weight)
    root_layer = SumLayer([root1, root2])

    spn = Spn(input_layer=input_layer,
              layers=[sum_layer, prod_layer, root_layer])

    print('===================')
    print(spn)
    print('===================')

    # feeding the inputs
    val1, val2, val3, val4, val5 = 0.0, 0.5, 0.3, 1.0, 0.0
    for node, val in zip((node1, node2, node3, node4, node5),
                         (val1, val2, val3, val4, val5)):
        node.set_val(val)

    # MPE evaluation
    res = spn.test_mpe_eval()
    print('spn eval\'d', res)

    # sum layer behaves as a max layer under MPE
    max1 = max(val * weight for val, weight in
               ((val1, weight11), (val2, weight12), (val3, weight13)))
    max2 = max(val * weight for val, weight in
               ((val2, weight22), (val3, weight23), (val4, weight24)))
    max3 = max(val * weight for val, weight in
               ((val3, weight33), (val4, weight34), (val5, weight35)))
    log_maxes = [log_or_zero(max_val) for max_val in (max1, max2, max3)]
    log_max1, log_max2, log_max3 = log_maxes

    print('expected max vals {0}, {1}, {2}'.format(*log_maxes))
    print('found    max vals {0}, {1}, {2}'.format(sum1.log_val,
                                                   sum2.log_val,
                                                   sum3.log_val))
    check_log_vals(log_maxes, (sum1, sum2, sum3))

    # product layer: log of a product is the sum of the logs
    prod_val1, prod_val2 = max1 * max2, max2 * max3
    prod_log_vals = [log_max1 + log_max2, log_max2 + log_max3]

    print('exp prod vals {0}, {1}'.format(*prod_log_vals))
    print('rea prod vals {0}, {1}'.format(prod1.log_val,
                                          prod2.log_val))
    check_log_vals(prod_log_vals, (prod1, prod2))

    # root layer: again a max over weighted children
    root_val1 = max(prod_val1 * weightr11, prod_val2 * weightr12)
    root_val2 = max(prod_val1 * weightr21, prod_val2 * weightr22)
    root_log_vals = [log_or_zero(root_val1), log_or_zero(root_val2)]

    print('exp root vals {0}, {1}'.format(*root_log_vals))
    print('found ro vals {0}, {1}'.format(root1.log_val,
                                          root2.log_val))
    check_log_vals(root_log_vals, (root1, root2))

    # top-down walk of the net
    print('mpe traversing')
    for first, second, third in spn.mpe_traversal():
        print(first, second, third)
# Example #8
# 0
def test_spn_backprop():
    """Check Spn.backprop on a three-layer SPN against hand-derived values.

    Five generic inputs feed three sum nodes, which feed two product nodes,
    which in turn feed two root sum nodes. After ``spn.test_eval()`` and
    ``spn.backprop()``, both roots must hold a derivative of 1 (log 0.0).
    The log-derivatives of the product, sum and input layers are then
    compared with the chain rule applied by hand.
    """

    def safe_log(value):
        # expected values may be exactly 0; map them to LOG_ZERO explicitly
        # instead of relying on log(0) raising inside a bare except
        return log(value) if value > 0. else LOG_ZERO

    def check_log(expected, found, *places):
        # log(0) has several encodings (e.g. LOG_ZERO + LOG_ZERO), so zero
        # values are compared with IS_LOG_ZERO instead of numerically
        if IS_LOG_ZERO(expected):
            assert IS_LOG_ZERO(found)
        else:
            assert_almost_equal(expected, found, *places)

    # create initial layer
    node1 = Node()
    node2 = Node()
    node3 = Node()
    node4 = Node()
    node5 = Node()
    nodes = [node1, node2, node3, node4, node5]

    input_layer = CategoricalInputLayer([node1, node2,
                                         node3, node4,
                                         node5])

    # top layer made by 3 sum nodes
    sum1 = SumNode()
    sum2 = SumNode()
    sum3 = SumNode()

    # linking to input nodes
    weight11 = 0.3
    sum1.add_child(node1, weight11)
    weight12 = 0.3
    sum1.add_child(node2, weight12)
    weight13 = 0.4
    sum1.add_child(node3, weight13)

    weight22 = 0.15
    sum2.add_child(node2, weight22)
    weight23 = 0.15
    sum2.add_child(node3, weight23)
    weight24 = 0.7
    sum2.add_child(node4, weight24)

    weight33 = 0.4
    sum3.add_child(node3, weight33)
    weight34 = 0.25
    sum3.add_child(node4, weight34)
    weight35 = 0.35
    sum3.add_child(node5, weight35)

    sum_layer = SumLayer([sum1, sum2, sum3])

    # another layer with two product nodes
    prod1 = ProductNode()
    prod2 = ProductNode()

    prod1.add_child(sum1)
    prod1.add_child(sum2)
    prod2.add_child(sum2)
    prod2.add_child(sum3)

    prod_layer = ProductLayer([prod1, prod2])

    # root layer, double sum
    root1 = SumNode()
    root2 = SumNode()

    weightr11 = 0.5
    root1.add_child(prod1, weightr11)
    weightr12 = 0.5
    root1.add_child(prod2, weightr12)

    weightr21 = 0.9
    root2.add_child(prod1, weightr21)
    weightr22 = 0.1
    root2.add_child(prod2, weightr22)

    root_layer = SumLayer([root1, root2])

    # create the spn
    spn = Spn(input_layer=input_layer,
              layers=[sum_layer, prod_layer, root_layer])

    # setting the input values
    val1 = 0.0
    node1.set_val(val1)
    val2 = 0.5
    node2.set_val(val2)
    val3 = 0.3
    node3.set_val(val3)
    val4 = 1.0
    node4.set_val(val4)
    val5 = 0.0
    node5.set_val(val5)

    # evaluating the spn
    res = spn.test_eval()
    print('spn eval\'d', res)

    # backprop
    spn.backprop()

    # computing derivatives by hand
    # topdown: root layer
    root_der = 1.0
    log_root_der = log(root_der)

    print('root ders', root1.log_der)
    assert_almost_equal(log_root_der, root1.log_der)
    assert_almost_equal(log_root_der, root2.log_der)

    # product layer: d(root)/d(prod) is the root's weight on that product
    prod_der1 = (root_der * weightr11 +
                 root_der * weightr21)

    prod_der2 = (root_der * weightr12 +
                 root_der * weightr22)

    log_prod_der1 = safe_log(prod_der1)
    log_prod_der2 = safe_log(prod_der2)

    print('found  prod ders', prod1.log_der, prod2.log_der)
    print('expect prod ders', log_prod_der1, log_prod_der2)

    check_log(log_prod_der1, prod1.log_der)
    check_log(log_prod_der2, prod2.log_der)

    # sum layer: d(prod)/d(sum) is the value of the product's other child
    sum_der1 = (
        prod_der1 * (weight22 * val2 +
                     weight23 * val3 +
                     weight24 * val4))

    log_sum_der1 = safe_log(sum_der1)

    sum_der2 = (prod_der1 * (weight11 * val1 +
                             weight12 * val2 +
                             weight13 * val3) +
                prod_der2 * (weight33 * val3 +
                             weight34 * val4 +
                             weight35 * val5))

    log_sum_der2 = safe_log(sum_der2)

    sum_der3 = (prod_der2 * (weight22 * val2 +
                             weight23 * val3 +
                             weight24 * val4))

    log_sum_der3 = safe_log(sum_der3)

    print('expected sum ders', log_sum_der1,
          log_sum_der2,
          log_sum_der3)
    print('found    sum ders', sum1.log_der,
          sum2.log_der,
          sum3.log_der)

    for log_sum_der, sum_node in zip((log_sum_der1, log_sum_der2,
                                      log_sum_der3),
                                     (sum1, sum2, sum3)):
        check_log(log_sum_der, sum_node.log_der)

    # final level, the first one: d(sum)/d(child) is the child's weight
    log_der1 = safe_log(sum_der1 * weight11)
    log_der2 = safe_log(sum_der1 * weight12 +
                        sum_der2 * weight22)
    log_der3 = safe_log(sum_der1 * weight13 +
                        sum_der2 * weight23 +
                        sum_der3 * weight33)
    log_der4 = safe_log(sum_der2 * weight24 +
                        sum_der3 * weight34)
    log_der5 = safe_log(sum_der3 * weight35)

    # printing, just in case
    print('child log der', node1.log_der, node2.log_der,
          node3.log_der, node4.log_der, node5.log_der)
    print('exact log der', log_der1, log_der2, log_der3,
          log_der4, log_der5)

    for log_der, node in zip((log_der1, log_der2, log_der3,
                              log_der4, log_der5), nodes):
        check_log(log_der, node.log_der, 15)
# Example #9
# 0
def test_product_node_backprop():
    """Check ProductNode.backprop for two product nodes sharing children.

    prod_node1 multiplies child1 and child2, and prod_node2 multiplies all
    three children. Both nodes backpropagate into the shared children, so
    each child's log-derivative must equal log(sum over parents of
    parent_der * product of that parent's other children), or LOG_ZERO when
    that sum is 0. Three rounds of child values are checked, including
    zero-valued children.
    """

    def safe_log(value):
        # expected values may be exactly 0; map them to LOG_ZERO explicitly
        # instead of relying on log(0) raising inside a bare except
        return log(value) if value > 0. else LOG_ZERO

    def check_log(expected, found, *places):
        # log(0) has several encodings (e.g. LOG_ZERO + LOG_ZERO), so zero
        # values are compared with IS_LOG_ZERO instead of numerically
        if IS_LOG_ZERO(expected):
            assert IS_LOG_ZERO(found)
        else:
            assert_almost_equal(expected, found, *places)

    # create child nodes
    child1 = Node()
    val1 = 1.
    child1.set_val(val1)

    child2 = Node()
    val2 = 1.
    child2.set_val(val2)

    child3 = Node()
    val3 = 0.0
    child3.set_val(val3)
    children = [child1, child2, child3]

    # create a product node and add children
    prod_node1 = ProductNode()
    prod_node1.add_child(child1)
    prod_node1.add_child(child2)

    # create a second node on all children
    prod_node2 = ProductNode()
    prod_node2.add_child(child1)
    prod_node2.add_child(child2)
    prod_node2.add_child(child3)

    # eval
    prod_node1.eval()
    prod_node2.eval()

    # set der and backprop
    prod_node_der1 = 1.0
    prod_node1.log_der = log(prod_node_der1)
    prod_node1.backprop()

    prod_node_der2 = 1.0
    prod_node2.log_der = log(prod_node_der2)
    prod_node2.backprop()

    # check for correctness
    log_der1 = safe_log(prod_node_der1 * val2 +
                        prod_node_der2 * val2 * val3)
    log_der2 = safe_log(prod_node_der1 * val1 +
                        prod_node_der2 * val1 * val3)
    log_der3 = safe_log(prod_node_der2 * val1 * val2)

    print('log ders 1:{lgd1} 2:{lgd2} 3:{lgd3}'.format(lgd1=log_der1,
                                                       lgd2=log_der2,
                                                       lgd3=log_der3))

    for log_der, child in zip((log_der1, log_der2, log_der3), children):
        check_log(log_der, child.log_der, 15)

    # setting different values for children
    val1 = 0.
    child1.set_val(val1)

    val2 = 0.
    child2.set_val(val2)

    val3 = 1.
    child3.set_val(val3)

    # eval
    prod_node1.eval()
    prod_node2.eval()

    for child in children:
        child.log_der = LOG_ZERO

    # set der and backprop
    prod_node_der1 = 0.5
    prod_node1.log_der = log(prod_node_der1)
    prod_node1.backprop()

    prod_node_der2 = 0.1
    prod_node2.log_der = log(prod_node_der2)
    prod_node2.backprop()

    # check for correctness
    log_der1 = safe_log(prod_node_der1 * val2 +
                        prod_node_der2 * val2 * val3)
    log_der2 = safe_log(prod_node_der1 * val1 +
                        prod_node_der2 * val1 * val3)
    log_der3 = safe_log(prod_node_der2 * val1 * val2)

    print('log ders 1:{lgd1} 2:{lgd2} 3:{lgd3}'.format(lgd1=log_der1,
                                                       lgd2=log_der2,
                                                       lgd3=log_der3))
    print('log ders 1:{lgd1} 2:{lgd2} 3:{lgd3}'.format(lgd1=child1.log_der,
                                                       lgd2=child2.log_der,
                                                       lgd3=child3.log_der))

    for log_der, child in zip((log_der1, log_der2, log_der3), children):
        check_log(log_der, child.log_der, 15)

    # setting different values for children
    val1 = 0.
    child1.set_val(val1)

    val2 = 0.2
    child2.set_val(val2)

    val3 = 1.
    child3.set_val(val3)

    # eval
    prod_node1.eval()
    prod_node2.eval()

    for child in children:
        child.log_der = LOG_ZERO

    # set der and backprop
    prod_node_der1 = 0.5
    prod_node1.log_der = log(prod_node_der1)
    prod_node1.backprop()

    prod_node_der2 = 0.1
    prod_node2.log_der = log(prod_node_der2)
    prod_node2.backprop()

    # check for correctness
    log_der1 = safe_log(prod_node_der1 * val2 +
                        prod_node_der2 * val2 * val3)
    log_der2 = safe_log(prod_node_der1 * val1 +
                        prod_node_der2 * val1 * val3)
    log_der3 = safe_log(prod_node_der2 * val1 * val2)

    print('log ders 1:{lgd1} 2:{lgd2} 3:{lgd3}'.format(lgd1=log_der1,
                                                       lgd2=log_der2,
                                                       lgd3=log_der3))
    print('log ders 1:{lgd1} 2:{lgd2} 3:{lgd3}'.format(lgd1=child1.log_der,
                                                       lgd2=child2.log_der,
                                                       lgd3=child3.log_der))

    for log_der, child in zip((log_der1, log_der2, log_der3), children):
        check_log(log_der, child.log_der, 15)