def test_product_layer_create_and_eval():
    """Check that a ProductLayer evaluates each product node to the log of the
    product of its children's values, with a zero product reported as LOG_ZERO.

    Three input nodes (0.8, 1.0, 0.0) feed three product nodes, each over a
    distinct pair of inputs, so two of the three products are zero.
    """
    # creating generic nodes
    node1 = Node()
    node2 = Node()
    node3 = Node()
    # whose values are
    val1 = 0.8
    val2 = 1.
    val3 = 0.
    node1.set_val(val1)
    node2.set_val(val2)
    node3.set_val(val3)

    # creating product nodes
    prod1 = ProductNode()
    prod2 = ProductNode()
    prod3 = ProductNode()
    # adding children: prod1 = n1*n2, prod2 = n1*n3, prod3 = n2*n3
    prod1.add_child(node1)
    prod1.add_child(node2)
    prod2.add_child(node1)
    prod2.add_child(node3)
    prod3.add_child(node2)
    prod3.add_child(node3)

    # adding product nodes to layer
    product_layer = ProductLayer([prod1, prod2, prod3])

    # evaluating
    product_layer.eval()

    # getting log vals
    layer_evals = product_layer.node_values()
    print('layer eval nodes')
    print(layer_evals)

    # computing our values; log(0) is undefined, so a zero product maps
    # to the LOG_ZERO sentinel instead
    prodvals = [val1 * val2, val1 * val3, val2 * val3]
    logvals = [log(p) if p > 0. else LOG_ZERO for p in prodvals]
    print('log vals')
    print(logvals)

    # zip() silently drops unmatched entries, so pin the count first
    assert len(layer_evals) == len(logvals)
    for logval, layer_eval in zip(logvals, layer_evals):
        if logval == LOG_ZERO:
            # for zero log check this way for correctness
            assert IS_LOG_ZERO(layer_eval)
        else:
            assert_almost_equal(logval, layer_eval, PRECISION)
def _safe_log(x):
    """Return log(x), or LOG_ZERO when x is not strictly positive."""
    return log(x) if x > 0. else LOG_ZERO


def _expected_child_log_ders(vals, prod_ders):
    """Return the exact log derivatives of the five input nodes.

    The topology is prod1 = n1*n2*n3, prod2 = n2*n3*n4, prod3 = n3*n4*n5.
    For each node, sum over its product parents the parent's derivative
    times the product of that parent's other children.

    vals: values of node1..node5.
    prod_ders: (linear-space) derivatives of prod1..prod3.
    """
    val1, val2, val3, val4, val5 = vals
    prod_der1, prod_der2, prod_der3 = prod_ders
    ders = [
        prod_der1 * val2 * val3,
        prod_der1 * val1 * val3 + prod_der2 * val3 * val4,
        (prod_der2 * val2 * val4 + prod_der3 * val4 * val5 +
         prod_der1 * val1 * val2),
        prod_der2 * val2 * val3 + prod_der3 * val3 * val5,
        prod_der3 * val3 * val4,
    ]
    return [_safe_log(d) for d in ders]


def _set_input_values(nodes, vals):
    """Assign vals[i] to nodes[i] via Node.set_val."""
    for node, val in zip(nodes, vals):
        node.set_val(val)


def _backprop_and_check(prod_layer, prods, nodes, vals, prod_ders):
    """Evaluate the layer, seed the parents' log derivatives, backprop and
    compare every child's log_der with the exact value (to 15 decimals)."""
    # evaluating
    prod_layer.eval()
    print('eval\'d layer:', prod_layer.node_values())

    # set the parent derivatives (a zero derivative becomes LOG_ZERO)
    for prod, der in zip(prods, prod_ders):
        prod.log_der = _safe_log(der)

    # back prop layer wise
    prod_layer.backprop()

    # check for correctness
    expected = _expected_child_log_ders(vals, prod_ders)

    # printing, just in case
    print('child log der', *[node.log_der for node in nodes])
    print('exact log der', *expected)

    for node, exact in zip(nodes, expected):
        if IS_LOG_ZERO(exact):
            assert IS_LOG_ZERO(node.log_der)
        else:
            assert_almost_equal(exact, node.log_der, 15)


def test_prod_layer_backprop():
    """Check ProductLayer.backprop against hand-computed child derivatives.

    Five input nodes feed three product nodes over overlapping triples. Two
    input configurations are tested, each containing zeros, so both the
    LOG_ZERO path and the finite path of the comparison are exercised.
    """
    # input layer made of 5 generic nodes
    nodes = [Node() for _ in range(5)]
    node1, node2, node3, node4, node5 = nodes
    input_layer = CategoricalInputLayer([node1, node2, node3, node4, node5])

    # top layer made by 3 prod nodes
    prod1 = ProductNode()
    prod2 = ProductNode()
    prod3 = ProductNode()
    prods = [prod1, prod2, prod3]

    # linking to input nodes
    prod1.add_child(node1)
    prod1.add_child(node2)
    prod1.add_child(node3)
    prod2.add_child(node2)
    prod2.add_child(node3)
    prod2.add_child(node4)
    prod3.add_child(node3)
    prod3.add_child(node4)
    prod3.add_child(node5)
    prod_layer = ProductLayer(prods)

    # parent derivatives used in both rounds (linear space)
    prod_ders = (1.0, 1.0, 0.0)

    # first round: setting input values
    first_vals = (0.0, 0.5, 0.3, 1.0, 0.0)
    _set_input_values(nodes, first_vals)
    print('input', [node.log_val for node in input_layer.nodes()])
    _backprop_and_check(prod_layer, prods, nodes, first_vals, prod_ders)

    # resetting derivatives so the second round's values are not polluted
    # by what the first backprop accumulated
    for node in nodes:
        node.log_der = LOG_ZERO

    # second round: setting new values as inputs
    second_vals = (0.0, 0.0, 0.3, 1.0, 1.0)
    _set_input_values(nodes, second_vals)
    _backprop_and_check(prod_layer, prods, nodes, second_vals, prod_ders)