def test_node(node, df, Y, regression=False):
    """
    Evaluate an already-built tree on a dataframe by walking it recursively.

    For an internal node, the rows of ``df`` are split with the node's stored
    feature/criteria. The resulting presence arrays are pushed into the
    children, which are then evaluated in turn. For a leaf, the node's stored
    prediction is scored against the rows reaching it, and the error is
    recorded via ``node.test_leaf``.

    :param node: Node object defined in Stats
    :param df: The dataframe being used by the tree
    :param Y: Feature to predict
    :param regression: if True, score leaves with ``mystats.compute_MSE``;
        otherwise score them with ``mystats.binary_error``
    :return: void
    """
    print('Testing Branching Level : ' + str(node.level))
    data = node.get_node_data(df)
    print('Length of TEST data ' + str(len(data)) + ' len df: ' + str(len(df)))
    feature = node.label['feature']
    label = node.label['criteria']
    # An empty feature name is treated as a leaf. Compare by value: `is not ''`
    # depends on string interning and is not guaranteed to behave.
    if feature != '':
        print('feature ' + feature)
        A_array, B_array = node.split(feature, df[feature], label)
        print('Test A : {} B: {}'.format(sum(A_array), sum(B_array)))
        # Check each child for None *before* using it, so a missing child is
        # skipped instead of raising AttributeError. Presence is set on both
        # children before either subtree is walked.
        if node.left is not None:
            node.left.set_presence(A_array)
        if node.right is not None:
            node.right.set_presence(B_array)
        if node.left is not None:
            test_node(node.left, df, Y, regression)
        if node.right is not None:
            test_node(node.right, df, Y, regression)
    else:
        predict = node.predict
        if not regression:
            error = mystats.binary_error(data, Y, predict)
        else:
            error = mystats.compute_MSE(predict, list(data[Y]))
        node.test_leaf(error)
def branch_node(node, df, threshold, Y, regression=False):
    """
    Recursively grow the tree below ``node``.

    The best feature/criteria split is searched on the rows reaching
    ``node``. If a split is found and the depth limit has not been reached,
    two children are created from the split's presence arrays and grown in
    turn. Otherwise ``node`` becomes a leaf holding a prediction and its
    training error (via ``node.leaf``).

    :param node: Node object defined in Stats
    :param df: The dataframe being used by the tree
    :param threshold: max branching depth; a node whose level is >= threshold
        is made a leaf
    :param Y: Feature to predict
    :param regression: if True, a leaf predicts the mean of ``Y`` and is
        scored with ``mystats.compute_MSE``; otherwise it predicts 0 or 1 and
        is scored with ``mystats.binary_error``
    :return: void
    """
    print('Branching Level : ' + str(node.level))
    data = node.get_node_data(df)
    print('Length of data ' + str(len(data)) + ' len df: ' + str(len(df)))
    feature, label = mytree.find_best_feature_and_label_for_split(data, Y, regression)
    print('feature: {} label: {}'.format(feature, label))
    if feature is not None and node.level < threshold:
        A_array, B_array = node.split(feature, df[feature], label)
        print(' A : {} B: {}'.format(sum(A_array), sum(B_array)))
        node.add_left(A_array)
        node.add_right(B_array)
        branch_node(node.left, df, threshold, Y, regression)
        branch_node(node.right, df, threshold, Y, regression)
    else:
        if not regression:
            prob = mystats.binary_probability(data, Y)
            print('PROBABILITY ' + str(prob))
            # Presumably prob is P(Y == 1) among the node's rows -- confirm in
            # mystats. Predict 1 when it is at least .5, else 0.
            predict = 1 if prob >= .5 else 0
            error = mystats.binary_error(data, Y, predict)
        else:
            print('leaf feature: {} label: {} presence: {}'.format(
                feature, label, node.presence))
            # float() forces true division under Python 2.
            # NOTE(review): raises ZeroDivisionError if no rows reach this node.
            predict = float(sum(data[Y])) / len(data[Y])
            error = mystats.compute_MSE(predict, list(data[Y]))
        node.leaf(predict, error)