def test_sum_node_create_and_eval_keras():
    """Check that SumNode's scalar eval agrees with its Keras-compiled graph.

    For each random trial:
    - build a sum node over a random number of leaf ``Node`` children with
      normalized random weights;
    - evaluate it one instance at a time via ``set_val``/``eval``;
    - evaluate the Keras function built by ``build_k`` on the whole batch of
      log-probabilities;
    - assert the two agree to 4 decimals.
    """
    n_trials = 100
    # A seeded local generator makes any failing trial reproducible and
    # leaves the global numpy RNG state seen by other tests untouched.
    rand_gen = numpy.random.RandomState(1337)

    for _ in range(n_trials):
        n_children = rand_gen.randint(1, 100)
        print('n children', n_children)
        children = [Node() for _ in range(n_children)]
        weights = rand_gen.rand(n_children)
        weights = weights / weights.sum()

        #
        # create sum node and adding children to it
        sum_node = SumNode()
        for child, w in zip(children, weights):
            sum_node.add_child(child, w)
            # symbolic 2-D input for this child's log-values; it is fed an
            # (n_instances, 1) column of log-probabilities below
            child.log_vals = K.placeholder(ndim=2)

        assert len(sum_node.children) == n_children
        assert len(sum_node.weights) == n_children
        assert len(sum_node.log_weights) == n_children
        print(sum_node)

        #
        # evaluating for fake probabilities
        n_instances = rand_gen.randint(1, 100)
        print('n instances', n_instances)
        # float64 values; assumes K.function downcasts to the placeholder
        # dtype when needed — TODO confirm for the active backend
        probs = rand_gen.rand(n_instances, n_children)
        log_probs = numpy.log(probs)

        # reference: scalar evaluation, one instance (row) at a time
        log_vals = []
        for d in range(n_instances):
            for c, child in enumerate(children):
                child.set_val(probs[d, c])
            sum_node.eval()
            print('sum node eval')
            print(sum_node.log_val)
            log_vals.append(sum_node.log_val)

        #
        # now the Keras backend graph, evaluated on the whole batch at once
        sum_node.build_k()
        eval_sum_node_f = K.function(
            inputs=[child.log_vals for child in children],
            outputs=[sum_node.log_vals])
        keras_log_vals = eval_sum_node_f(
            [log_probs[:, c].reshape(n_instances, 1)
             for c in range(n_children)])[0]
        print(keras_log_vals)

        assert_array_almost_equal(
            numpy.array(log_vals).reshape(n_instances, 1),
            keras_log_vals,
            decimal=4)
def test_sum_node_create_and_eval():
    """Check SumNode child/weight bookkeeping and its log-space evaluation.

    Two leaf children are attached with weights 0.8 and 0.2. The node is
    then evaluated for the child values (1, 1), (1, 0) and (0, 0). The
    expected ``log_val`` is ``log(w1 * v1 + w2 * v2)``, or ``LOG_ZERO`` when
    that weighted sum is zero, since log(0) has no finite value.
    """
    # create child nodes
    child1 = Node()
    child2 = Node()

    # create sum node and adding children to it
    weight1 = 0.8
    weight2 = 0.2
    sum_node = SumNode()
    sum_node.add_child(child1, weight1)
    sum_node.add_child(child2, weight2)

    assert len(sum_node.children) == 2
    assert len(sum_node.weights) == 2
    assert len(sum_node.log_weights) == 2
    # Compare element-wise with a tolerance. Exact ``==`` on computed floats
    # is brittle, and would raise if log_weights were a numpy array.
    expected_log_weights = [log(weight1), log(weight2)]
    for expected, actual in zip(expected_log_weights, sum_node.log_weights):
        assert_almost_equal(actual, expected, places=15)
    print(sum_node)

    # evaluating for several child-value configurations
    for val1, val2 in ((1., 1.), (1., 0.), (0., 0.)):
        child1.set_val(val1)
        child2.set_val(val2)
        sum_node.eval()
        print(sum_node.log_val)
        weighted_sum = val1 * weight1 + val2 * weight2
        # an all-zero mixture has no finite log: the node must report LOG_ZERO
        expected = log(weighted_sum) if weighted_sum > 0 else LOG_ZERO
        assert_almost_equal(sum_node.log_val, expected, places=15)
def test_sum_node_backprop():
    """Check log-space backprop from two sum nodes into shared children.

    Both children (value 1.0) feed two parent sum nodes with different
    weights. After each parent's ``log_der`` is set and ``backprop()`` is
    called, each child's ``log_der`` must equal
    ``log(sum_p w_p,child * der_p)``. This is checked for parent derivatives
    (1, 1) and then, after resetting the children to LOG_ZERO, for (0.5, 0).
    """
    # create child nodes
    child1 = Node()
    child1.set_val(1.)
    child2 = Node()
    child2.set_val(1.)

    # first parent
    w11, w12 = 0.8, 0.2
    sum_node1 = SumNode()
    sum_node1.add_child(child1, w11)
    sum_node1.add_child(child2, w12)

    # a coparent over the same children
    w21, w22 = 0.6, 0.4
    sum_node2 = SumNode()
    sum_node2.add_child(child1, w21)
    sum_node2.add_child(child2, w22)

    sum_node1.eval()
    sum_node2.eval()

    def propagate_and_check(parent_der1, parent_der2):
        # Parents receive log-derivatives; a zero derivative is LOG_ZERO.
        sum_node1.log_der = log(parent_der1) if parent_der1 > 0 else LOG_ZERO
        sum_node1.backprop()
        sum_node2.log_der = log(parent_der2) if parent_der2 > 0 else LOG_ZERO
        sum_node2.backprop()

        expected1 = log(w11 * parent_der1 + w21 * parent_der2)
        expected2 = log(w12 * parent_der1 + w22 * parent_der2)
        print('log ders 1:{lgd1} 2:{lgd2}'.format(lgd1=expected1,
                                                  lgd2=expected2))
        assert_almost_equal(expected1, child1.log_der, 15)
        assert_almost_equal(expected2, child2.log_der, 15)

    propagate_and_check(1.0, 1.0)

    # reset the children before a second pass, so contributions from the
    # first pass are cleared before the second pass writes new log_der values
    child1.log_der = LOG_ZERO
    child2.log_der = LOG_ZERO

    propagate_and_check(0.5, 0.0)