예제 #1
0
def main():
    """Run the decision-tree experiment and print a report.

    Prints, in order:
      * the most common label and the entropy of the training data,
      * the attribute name and information gain reported by the tree that
        ``dt.id3`` builds from the training data,
      * the accuracy on the training set and on the test set,
      * the depth selected by cross-validation and the test-set accuracy of
        a tree limited to that depth.

    Accuracy is reported as ``100 - error``, which assumes that
    ``dt.getOverallError`` returns the error as a percentage (TODO confirm).

    Relies on the module-level names ``dt``, ``trainingData``, ``testData``,
    ``crossValidation`` and ``getMaxAccuracyDepth``.
    """
    # The statistics below are labelled "training data", so they must be
    # computed from trainingData (they were previously fed testData).
    print('Most common label in training data: ',
          dt.returnMostCommon(trainingData, 'label'), '\n')

    print('Entropy of the training data: ',
          dt.calculateTotalEntropy(trainingData), '\n')

    # Labels must come from the same data set the tree is trained on;
    # passing the test-set label column here mismatched rows and labels.
    decisionTree = dt.id3(trainingData, trainingData.attributes,
                          trainingData.get_column('label'))
    print('Best feature: ', decisionTree.getAttributeName())
    print('    ...and its information gain: ',
          decisionTree.getInformationGain(), '\n')

    trainingError, trainingDepth = dt.getOverallError(trainingData,
                                                      decisionTree)
    print('Accuracy on the training set: ', 100 - trainingError, '\n')

    testError, testDepth = dt.getOverallError(testData, decisionTree)
    print('Accuracy on the test set: ', 100 - testError, '\n')

    print('////////// CROSS VALIDATION //////////')
    accuracies = crossValidation()
    bestDepth = getMaxAccuracyDepth(accuracies)
    print('Best depth: ', bestDepth)

    # A fresh tree is built before limiting its depth. Whether
    # dt.limitTreeDepth modifies its argument is not visible here, so the
    # full tree above is deliberately not reused.
    depthDecisionTree = dt.id3(trainingData, trainingData.attributes,
                               trainingData.get_column('label'))
    depthLimitTree = dt.limitTreeDepth(depthDecisionTree, bestDepth)
    err, depth = dt.getOverallError(testData, depthLimitTree)
    print('Accuracy on test set using the best depth: ', 100 - err, '\n')
예제 #2
0
def main():
    """Build a full ID3 tree, then compare it with a depth-limited tree.

    Steps, in order:
      1. Load the training CSV and build a tree from it with ``id3.id3``.
      2. Print the tree, its root attribute name, and its error/depth on
         the training data.
      3. Load the test CSV and print the full tree's error/depth on it.
      4. Limit the tree to depth 5 with ``id3.limitTreeDepth`` and print
         the resulting error/depth on the test data.

    Relies on the module-level names ``id3`` and ``get_data_obj``.
    """

    def report(template, dataset, tree):
        # Evaluate ``tree`` on ``dataset`` and print the (error, depth)
        # pair using the given message template.
        err, tree_depth = id3.getOverallError(dataset, tree)
        print(template.format(err, tree_depth))

    test_template = "    Error on test data: {}%; Depth: {}"

    print("\nFull Decision Tree: ")
    train_set = get_data_obj('data/data_semeion/hand_data_train.csv')
    attributes = train_set.attributes
    labels = train_set.get_column('label')
    full_tree = id3.id3(train_set, attributes, labels)
    print(full_tree)
    print('root node: ', full_tree.getAttributeName())

    report("Error on training data: {}%; Depth: {}", train_set, full_tree)

    test_set = get_data_obj('data/data_semeion/hand_data_test.csv')
    report(test_template, test_set, full_tree)

    print("\nTree with Max Depth 5")

    depth_limit = 5
    shallow_tree = id3.limitTreeDepth(full_tree, depth_limit)
    report(test_template, test_set, shallow_tree)
예제 #3
0
def crossValidation():
    """Measure the average accuracy of each depth limit over all folds.

    For every value in the module-level ``depths``, a tree is built on each
    fold's training split, limited to that depth, and evaluated on the
    matching test split. The mean and standard deviation of the per-fold
    accuracies (``100.0 - error``) are printed for each depth.

    Returns:
        list: the mean accuracy for each depth limit, in the order the
        limits appear in ``depths``.
    """
    filenames, train_folds, test_folds = crossValidationSetup()

    def fold_accuracy(fold, depth_limit):
        # Train on one fold, cap the tree depth, and score it on the
        # matching held-out fold.
        train_split = train_folds[fold]
        tree = dt.id3(train_split, train_split.attributes,
                      train_split.get_column('label'))
        limited_tree = dt.limitTreeDepth(tree, depth_limit)
        error, _ = dt.getOverallError(test_folds[fold], limited_tree)
        return 100.0 - error

    mean_per_depth = []
    for depth_limit in depths:
        print('Testing Depth Limit: ', depth_limit)

        # The number of folds is taken from the number of file names.
        scores = [fold_accuracy(fold, depth_limit)
                  for fold in range(len(filenames))]

        mean_score = statistics.mean(scores)
        mean_per_depth.append(mean_score)
        print("Average accuracy: ", mean_score)
        print("Standard deviation: ", statistics.stdev(scores), '\n')

    return mean_per_depth
예제 #4
0
 def vote(self, dataObject, tempTree):
     """Record the predictions of ``tempTree`` on ``dataObject`` as a vote.

     Calls ``id3.getOverallError`` and appends the third element of its
     three-element result to ``self.overallVotes``; the first two
     elements are discarded.
     """
     _first, _second, predictions = id3.getOverallError(dataObject, tempTree)
     self.overallVotes.append(predictions)