예제 #1
0
파일: ass4.py 프로젝트: bonnywong/DD2431
def pruneTree(tree, validation):
    """Greedily prune ``tree`` while validation accuracy strictly improves.

    Each round scores every tree returned by ``d.allPruned`` for the current
    best tree with ``d.check`` on ``validation``. It moves to the
    highest-scoring one (the first one on ties) only if it beats the best
    score so far.

    Fixes over the original:
    * a ``besttree``/``bestTree`` typo caused a NameError when the first
      round found nothing;
    * the tree returned after a non-improving round was that round's
      (worse) winner, not the best tree found;
    * the unpruned tree now competes, and an empty pruning list ends the
      search.

    :param tree: decision tree to prune.
    :param validation: dataset passed to ``d.check``.
    :returns: the best-scoring tree found (``tree`` itself if no pruning
        beats it).
    """
    bestTree = tree
    bestGain = d.check(tree, validation)

    while True:
        candidates = d.allPruned(bestTree)
        if not candidates:
            return bestTree
        scores = [d.check(c, validation) for c in candidates]
        roundGain = max(scores)
        if roundGain <= bestGain:
            return bestTree
        bestGain = roundGain
        # index() yields the first maximum, matching the original strict '>' scan.
        bestTree = candidates[scores.index(roundGain)]
예제 #2
0
def prunedTree(training, validation):
    """Build a tree on ``training`` and prune it greedily against ``validation``.

    Each round moves to the highest-scoring one-step pruning (the first one on
    ties) as long as its ``dtree.check`` score is at least the current
    tree's, so equally good simpler trees are preferred.

    The original called ``max()`` on an empty list, raising ``ValueError``,
    once the tree had been pruned to a point with no further prunings. That
    case now returns the tree.
    """
    tree = dtree.buildTree(training, m.attributes)
    current = dtree.check(tree, validation)
    while True:
        candidates = dtree.allPruned(tree)
        if not candidates:
            return tree
        scores = [dtree.check(c, validation) for c in candidates]
        best = max(scores)
        if best < current:
            return tree
        tree = candidates[scores.index(best)]
        current = best
예제 #3
0
def check_pruning(data_set):
    """Score every one-step pruning of a tree trained on ``data_set.Train``.

    Each candidate from ``d.allPruned`` is mapped to its ``d.check`` score on
    ``data_set.Test``. The mapping is handed to ``key_with_maxval`` (defined
    elsewhere; presumably it picks the highest-scoring candidate), and that
    function's result is returned.

    NOTE(review): candidates are scored on the Test split rather than a
    separate validation split; confirm this is intended.
    """
    full_tree = d.buildTree(data_set.Train, m.attributes)
    scores = {
        candidate: d.check(candidate, data_set.Test)
        for candidate in d.allPruned(full_tree)
    }
    return key_with_maxval(scores)
예제 #4
0
def prun(tree, val):
    """Map each one-step pruning of ``tree`` to its ``dt.check`` score on ``val``."""
    return {candidate: dt.check(candidate, val) for candidate in dt.allPruned(tree)}
예제 #5
0
def getData1(iterations, fractions=(0.3, 0.4, 0.5, 0.6, 0.7, 0.8)):
    """Score reduced-error-pruned MONK-1 trees for several training fractions.

    For each fraction and each of ``iterations`` random splits of
    ``mdata.monk1``:
    1. build a tree on the training part;
    2. prune it greedily against the validation part, accepting ties so
       that simpler trees win;
    3. record ``dtree.check`` of the pruned tree on ``mdata.monk1test``.

    :param iterations: number of random splits per fraction.
    :param fractions: training fractions passed to ``partition``. This was
        previously a hard-coded list whose length (6) was duplicated in the
        result allocation.
    :returns: one list of ``iterations`` scores per fraction. NOTE: the
        stored value is the raw ``dtree.check`` score, despite the ``error``
        name.
    """
    error = [[0] * iterations for _ in fractions]
    for f, fraction in enumerate(fractions):
        for i in range(iterations):
            monk1train, monk1val = partition(mdata.monk1, fraction)
            monk1tree = dtree.buildTree(monk1train, mdata.attributes)
            while True:
                temptree = monk1tree
                # Cache the current best score instead of recomputing it per candidate.
                tempScore = dtree.check(temptree, monk1val)
                for x in dtree.allPruned(monk1tree):
                    score = dtree.check(x, monk1val)
                    if score >= tempScore:
                        temptree, tempScore = x, score
                if temptree == monk1tree:
                    break
                monk1tree = temptree
            error[f][i] = dtree.check(monk1tree, mdata.monk1test)
    return error
예제 #6
0
def assignment4_p1(data, attributes, fraction):
    """Reduced-error pruning on a random train/validation split.

    The tree is built on the first part of ``partition(data, fraction)``.
    Each round scans the one-step prunings of the current best tree and moves
    to any with a strictly lower validation error ``1 - d.check``. The search
    stops when a round finds no improvement. Progress is printed.

    Fixes over the original: it returned ``None`` when the first round found
    nothing, and looped forever when a later round found nothing. Its
    ``bestErrorRate > orgErr`` branch could never fire.

    :returns: the tree with the lowest validation error found (the unpruned
        tree if no pruning helps).
    """
    trainData, validData = partition(data, fraction)
    bestTree = d.buildTree(trainData, attributes)
    bestErr = 1 - d.check(bestTree, validData)
    print("ORIGINAL ERR", bestErr)
    while True:
        prunedTrees = d.allPruned(bestTree)
        print(len(prunedTrees))
        improved = False
        for i, candidate in enumerate(prunedTrees):
            err = 1 - d.check(candidate, validData)
            print(i, err)
            if err < bestErr:
                bestErr = err
                bestTree = candidate
                improved = True
                print("Best Error Rate:", bestTree, bestErr)
        if not improved:
            return bestTree
예제 #7
0
def pruning(data_set, fraction=0.6):
    """Return a decision tree built on part of ``data_set`` and greedily pruned.

    ``partition(data_set, fraction)`` splits the data: the first part trains
    the tree and the second part scores prunings via ``dtree.check``. Each
    round moves to the highest-scoring pruning whose score is at least the
    current tree's. The search stops when no pruning qualifies.
    """
    train_part, val_part = partition(data_set, fraction)
    current = dtree.buildTree(train_part, m.attributes)
    current_score = dtree.check(current, val_part)

    while True:
        choice, choice_score = None, 0
        for candidate in dtree.allPruned(current):
            score = dtree.check(candidate, val_part)
            if score >= current_score and score > choice_score:
                choice, choice_score = candidate, score
        if choice is None:
            return current
        current, current_score = choice, choice_score
def pruneTree(dataset, testSet):
    """For each training fraction, apply the best single pruning and report test error.

    For each fraction, ``dataset`` is split with ``partition`` and a tree is
    built on the training part. That tree competes with its one-step
    prunings on validation error ``1 - tree.check``. The winner is drawn,
    and its error on ``testSet`` is recorded.

    Fixes over the original: it started from ``bestTree = 0``, which crashed
    ``drawTree``/``check`` when the tree had no prunings. It also always
    picked a pruning, even one worse than the unpruned tree.

    :returns: list of test errors, one per fraction, in order.
    """
    fractions = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    errorList = []

    for x in fractions:
        train, val = partition(dataset, x)
        theTree = tree.buildTree(train, data.attributes)

        bestTree = theTree
        theBest = 1 - tree.check(theTree, val)
        for t in tree.allPruned(theTree):
            error = 1 - tree.check(t, val)
            if error < theBest:
                theBest = error
                bestTree = t

        draw.drawTree(bestTree)
        errorList.append(1 - tree.check(bestTree, testSet))

    return errorList
예제 #9
0
def calculate_best(Td,Vd):
    """Build a tree on ``Td`` and greedily prune it against ``Vd``.

    Each round scans the one-step prunings (``tree.allPruned``) of the
    current tree. It moves to the highest-scoring one (the first one on ties)
    if its ``tree.check`` score on ``Vd`` strictly beats the current tree.

    Fixes over the original:
    * the candidate list was never recomputed after moving, so only one
      level was ever pruned;
    * candidates were compared against ``-sys.maxsize`` instead of the
      unpruned tree, and an empty list returned that sentinel as the score;
    * the tree was built twice.

    :returns: ``(score, tree)`` for the final tree. ``score`` is the
        maximised ``tree.check`` value.
    """
    tr = tree.buildTree(Td, m.attributes)
    error = tree.check(tr, Vd)
    while True:
        best = None
        for candidate in tree.allPruned(tr):
            score = tree.check(candidate, Vd)
            if score > error:
                error, best = score, candidate
        if best is None:
            return error, tr
        tr = best
예제 #10
0
def findBestPrunedTree(originalTrainSet, fraction):
    """Greedily prune a tree trained on a random part of ``originalTrainSet``.

    ``originalTrainSet.dataset`` is split with ``partition`` and a tree is
    built on the first part. The tree is repeatedly replaced by the pruning
    that ``getBestPerformingTree`` picks on the second part, as long as that
    pruning scores at least as well. Progress is printed.

    :returns: ``(tree, score)`` for the last accepted tree.
    """
    trainSet, validationSet = partition(originalTrainSet.dataset, fraction)
    currentTree = d.buildTree(trainSet, m.attributes)
    currentScore = d.check(currentTree, validationSet)
    print("Pruning " + originalTrainSet.name + " with fraction = " +
          str(fraction) + " and performance on new validation set = " +
          str(currentScore))

    while True:
        candidates = d.allPruned(currentTree)
        if len(candidates) == 0:
            print("No more ways to prune tree. Returning.")
            break
        candidate, score = getBestPerformingTree(candidates, validationSet)
        if score < currentScore:
            print("All pruned trees perform worse. Stopping here.")
            break
        print("Found pruned tree which performed better: " + str(score))
        currentTree, currentScore = candidate, score
    return currentTree, currentScore
def prune(dec_tree, val_data):
    """Greedily prune ``dec_tree`` against ``val_data``.

    Each round starts from the current tree's ``check`` score. It scans
    ``allPruned(dec_tree)`` and switches to any candidate scoring at least the
    best seen so far in that round, so ties favour later candidates. Rounds
    repeat until one switches nothing.

    :returns: ``(score, tree)`` of the final tree.
    """
    while True:
        candidates = allPruned(dec_tree)
        best_score = check(dec_tree, val_data)
        switched = False
        for candidate in candidates:
            candidate_score = check(candidate, val_data)
            if candidate_score >= best_score:
                best_score, dec_tree, switched = candidate_score, candidate, True
        if not switched:
            return best_score, dec_tree
예제 #12
0
def pruneTree(train, validation, acc_desired):
    """Prune a tree built on ``train`` for up to ``acc_desired`` rounds.

    Despite its name, ``acc_desired`` is the maximum number of pruning rounds.
    Every round moves to the highest-scoring pruning of the current tree on
    ``validation`` (the later candidate on ties), even if that lowers the
    score.

    The original only guarded the final step, so an earlier drop in
    accuracy was returned unchecked. This version returns the best tree seen
    over all rounds, including the unpruned one. Ties favour the more pruned
    tree, as before.
    """
    t = d.buildTree(train, m.attributes)
    best_tree = t
    best_acc = d.check(t, validation)

    rounds = 0
    while rounds < acc_desired:
        rounds += 1
        candidates = d.allPruned(t)
        if not candidates:
            break  # nothing left to prune; further rounds would be no-ops
        step_acc = -1
        for candidate in candidates:
            acc = d.check(candidate, validation)
            if acc >= step_acc:
                step_acc, t = acc, candidate
        if step_acc >= best_acc:
            best_tree, best_acc = t, step_acc
    return best_tree
예제 #13
0
def getPrunedChildren(toPrune, bestErrorRate, validData):
    """Collect the end points of every non-worsening pruning path.

    For each tree in ``toPrune``, each one-step pruning (``d.allPruned``)
    whose validation error ``1 - d.check`` is at most ``bestErrorRate`` is
    explored recursively, using its own error as the new bound. A tree with
    no qualifying pruning is itself an end point.

    Fixes over the original:
    * the result accumulator was reset inside the inner loop, so only the
      last child's contribution survived;
    * a leaf (no prunings) hit an undefined or stale accumulator;
    * recursive results were appended as nested lists instead of being
      flattened.

    NOTE: paths are enumerated without memoisation, so the same tree may be
    returned several times and cost grows quickly with tree size.

    :returns: flat list of end-point trees.
    """
    endpoints = []
    for candidate in toPrune:
        extended = False
        for child in d.allPruned(candidate):
            err = 1 - d.check(child, validData)
            if err <= bestErrorRate:
                endpoints.extend(getPrunedChildren([child], err, validData))
                extended = True
        if not extended:
            endpoints.append(candidate)
    return endpoints
예제 #14
0
def oneprune(tree, valset):
    """Return every tree reachable from ``tree`` by strictly improving prunings.

    Prunings (``dtree.allPruned``) that score strictly higher than ``tree``
    on ``valset`` are followed recursively. A tree with no improving pruning
    is returned as-is, in a one-element list.

    The parent's score is now computed once rather than once per candidate.
    The result comprehension no longer rebinds the name ``tree``.
    """
    base = dtree.check(tree, valset)
    better = [tr for tr in dtree.allPruned(tree) if dtree.check(tr, valset) > base]
    if not better:
        return [tree]
    return [leaf for tr in better for leaf in oneprune(tr, valset)]
예제 #15
0
def best_pruned(base,valid_set):
    """Return ``(tree, score)`` for the best of ``base`` and its one-step prunings.

    Candidates are scored with ``d.check`` on ``valid_set``. On ties the later
    candidate wins, so an equally scoring pruning beats ``base``.
    """
    candidates = d.allPruned(base)
    best_tree, best_score = base, d.check(base, valid_set)
    for candidate in candidates:
        score = d.check(candidate, valid_set)
        if score >= best_score:
            best_tree, best_score = candidate, score
    return (best_tree, best_score)
예제 #16
0
def pruneNow(tree, data, testData):
    """Return the lowest error ``1 - d.check`` among ``tree`` and its one-step prunings.

    NOTE(review): the unpruned tree is scored on ``data`` while the prunings
    are scored on ``testData``. If these are different sets, the comparison
    mixes two datasets; confirm this is intended.
    """
    errors = [1 - d.check(tree, data)]
    errors.extend(1 - d.check(candidate, testData) for candidate in d.allPruned(tree))
    return min(errors)
예제 #17
0
파일: lab1.py 프로젝트: mkufel/ML
def prune(tree, valSet):
    """Recursively follow prunings that improve the ``dt.check`` score on ``valSet``.

    Each one-step pruning of ``tree`` that beats the current best score is
    itself pruned recursively, and the result becomes the new current tree.
    Later candidates must beat that tree's score.
    """
    best = tree
    best_score = dt.check(best, valSet)
    for candidate in dt.allPruned(tree):
        if dt.check(candidate, valSet) <= best_score:
            continue
        best = prune(candidate, valSet)
        best_score = dt.check(best, valSet)
    return best
예제 #18
0
def getClasification(dataset,fraction):
    """Pick the best one-step pruning of a tree built on one part of ``dataset``.

    NOTE(review): the tree is built on the *second* partition (``monk1val``)
    and prunings are scored on the *first* (``monk1train``). This looks
    swapped relative to the names; confirm before relying on it.

    :returns: ``(score, tree)`` of the first pruning with the highest
        positive ``tree.check`` score. Returns ``(0, None)`` when no pruning
        scores above 0 (the original raised ``UnboundLocalError`` there).
    """
    monk1train, monk1val = partition(dataset,fraction)
    testTree = tree.buildTree(monk1val,m.attributes)
    pValue = 0
    bestTree = None
    for pruned in tree.allPruned(testTree):
        score = tree.check(pruned, monk1train)
        if score > pValue:
            pValue = score
            bestTree = pruned
    return pValue, bestTree
예제 #19
0
def prune_tree(tree, validation):
    """Greedily prune ``tree`` while the best pruning scores at least as well.

    Each step moves to the highest-scoring one-step pruning (the first one on
    ties) whose ``d.check`` score on ``validation`` is not below the current
    tree's.

    Fixes over the original: it called ``max()`` on an empty list once it
    reached a tree with no prunings. It also located scores via
    ``list.index(candidate)``, which is O(n²) and relies on tree equality.
    """
    while True:
        candidates = d.allPruned(tree)
        if not candidates:
            return tree
        scores = [d.check(c, validation) for c in candidates]
        best = max(scores)
        if best < d.check(tree, validation):
            return tree
        tree = candidates[scores.index(best)]
예제 #20
0
def bestPrunedFromList(tree, validationDataset):
    """Return the best one-step pruning of ``tree`` if it beats ``tree`` on validation.

    :returns: the first pruning with the highest ``dtree.check`` score
        strictly above ``tree``'s own; otherwise ``tree`` itself. The
        original returned the last pruning in that case, whose score did not
        match the compared baseline. It also raised ``IndexError`` for a tree
        with no prunings.
    """
    bestValue = dtree.check(tree, validationDataset)
    bestTree = tree
    for candidate in dtree.allPruned(tree):
        value = dtree.check(candidate, validationDataset)
        if value > bestValue:
            bestValue, bestTree = value, candidate
    return bestTree
예제 #21
0
파일: Labb_1.py 프로젝트: jonte450/DD2421
def find_prunned(data_part, f_part):
    """Build a tree on a partition of ``data_part`` and apply the best single pruning.

    The first part from ``partition(data_part, f_part)`` trains the tree.
    The second part scores the tree and its one-step prunings. The first
    candidate with the strictly highest score wins, and the unpruned tree is
    kept on ties.
    """
    train_part, val_part = partition(data_part, f_part)
    full = tree.buildTree(train_part, dataset.attributes)
    candidates = tree.allPruned(full)
    chosen, chosen_score = full, tree.check(full, val_part)
    for candidate in candidates:
        score = tree.check(candidate, val_part)
        if score > chosen_score:
            chosen, chosen_score = candidate, score
    return chosen
예제 #22
0
def bestPrunedTree(trainer, validation):
    """Return the highest-scoring one-step pruning of ``trainer``.

    Candidates from ``d.allPruned`` are scored with ``d.check`` on
    ``validation``, and the first best one wins. When there is nothing to
    prune, a message is printed and ``trainer`` itself is returned.

    Fixes over the original: its empty-list check sat inside the loop and
    could never fire. ``max_tree`` was unbound (``UnboundLocalError``) when
    there were no candidates or all scored 0. The builtin ``max`` was
    shadowed.
    """
    pruneWays = d.allPruned(trainer)
    if len(pruneWays) == 0:
        print("Prune completed, no more left.")
        return trainer
    return max(pruneWays, key=lambda candidate: d.check(candidate, validation))
예제 #23
0
파일: lab1.py 프로젝트: ViktorCollin/ml13
def _pruned_test_error(train_set, test_set, sizes, runs, accept_ties):
    """Return the average test error of greedily pruned trees, one value per size."""
    averages = []
    for size in sizes:
        total = 0.0
        for _ in range(runs):
            training, test = partition(train_set, size)
            bestTree = dt.buildTree(training, data.attributes)
            bestClass = dt.check(bestTree, test)
            better = True
            while better:
                better = False
                for subTree in dt.allPruned(bestTree):
                    score = dt.check(subTree, test)
                    if score > bestClass or (accept_ties and score == bestClass):
                        bestTree, bestClass, better = subTree, score, True
            total += 1 - dt.check(bestTree, test_set)
        averages.append(total / runs)
    return averages


def prune():
    """Assignment 4: print the average test error of pruned MONK-1/MONK-3 trees.

    For each partition size, 100 random splits are made. A tree is built on
    the first part and greedily pruned against the second part. The mean of
    ``1 - dt.check`` on the official test set is printed per size.

    The two previously duplicated sections now share
    ``_pruned_test_error``. Prints use a single parenthesised argument, so
    the output is identical on Python 2 and the code also runs on Python 3.

    NOTE(review): MONK-1 accepts only strictly better prunings while MONK-3
    also accepts ties, exactly as in the original; confirm whether the
    datasets are meant to use different rules.
    """
    print("\n------------------------------\nAssignment 4 - Pruning\n------------------------------")
    print("Dataset\t  0.3\t\t  0.4\t\t  0.5\t\t  0.6\t\t  0.7\t\t  0.8")
    partSizes = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    monk1 = _pruned_test_error(data.monk1, data.monk1test, partSizes, 100, False)
    print("Monk1\t%0.6f\t%0.6f\t%0.6f\t%0.6f\t%0.6f\t%0.6f\t" % tuple(monk1))
    monk3 = _pruned_test_error(data.monk3, data.monk3test, partSizes, 100, True)
    print("Monk3\t%0.6f\t%0.6f\t%0.6f\t%0.6f\t%0.6f\t%0.6f\t" % tuple(monk3))
예제 #24
0
def getBestTree(bestTree, bestTreeError, monkVal):
    """Greedily prune ``bestTree`` while the validation error strictly decreases.

    Each pass scans the one-step prunings of the tree held at the start of
    the pass. It adopts any pruning whose error ``1 - d.check`` on
    ``monkVal`` is below the best so far. Passes repeat until one adopts
    nothing.

    :returns: ``(tree, error)``.
    """
    improved = True
    while improved:
        improved = False
        for candidate in d.allPruned(bestTree):
            candidate_error = 1 - d.check(candidate, monkVal)
            if candidate_error < bestTreeError:
                bestTree, bestTreeError = candidate, candidate_error
                improved = True
    return bestTree, bestTreeError
예제 #25
0
def pruning(input_tree,validation):
    """Greedily prune ``input_tree`` and return the resulting validation error.

    Each round moves to the one-step pruning with the lowest error
    ``1 - dtree.check`` on ``validation`` (the first one on ties), as long as
    that error does not exceed the current one.

    Fixes over the original: it always raised ``IndexError`` by assigning
    into an empty list. Its loop condition was inverted, continuing while
    the error got worse. It also never evaluated ``input_tree`` itself.

    :returns: validation error of the final tree.
    """
    error = 1 - dtree.check(input_tree, validation)
    while True:
        alternatives = dtree.allPruned(input_tree)
        if not alternatives:
            return error
        errs = [1 - dtree.check(alt, validation) for alt in alternatives]
        best = min(errs)
        if best > error:
            return error
        error = best
        input_tree = alternatives[errs.index(best)]
예제 #26
0
def selectBestTree(tree, best_score, dataset):
    """Depth-first search for prunings scoring at least ``best_score``.

    For each one-step pruning of ``tree`` (except the first one), a
    ``dtree.check`` score on ``dataset`` that reaches the current best
    triggers a recursive search from that pruning. The recursion's result
    becomes the new current ``(tree, best_score)``.

    NOTE(review): the first element returned by ``dtree.allPruned`` is
    skipped (``[1:]``), exactly as in the original; confirm that this
    candidate is meant to be excluded.

    :returns: ``(tree, score)``.
    """
    candidates = dtree.allPruned(tree)
    for candidate in candidates[1:]:
        score = dtree.check(candidate, dataset)
        if score < best_score:
            continue
        tree, best_score = selectBestTree(candidate, score, dataset)
    return tree, best_score
예제 #27
0
파일: a7.py 프로젝트: chjen1994/dd2421_lab1
def checkperformance(tree, monk1val):
    """Recursively move to the best one-step pruning while it is not worse.

    The first pruning with the highest ``d.check`` score on ``monk1val`` is
    followed recursively if it scores at least as well as ``tree``.
    Otherwise ``tree`` is returned, which is also the case when there are no
    prunings.
    """
    best_tree = None
    best_score = -1
    for candidate in d.allPruned(tree):
        score = d.check(candidate, monk1val)
        if score > best_score:
            best_tree, best_score = candidate, score
    if best_score < d.check(tree, monk1val):
        return tree
    return checkperformance(best_tree, monk1val)
예제 #28
0
파일: Pruning.py 프로젝트: dsouzarc/kth
def prune_tree(monkdata_set, num_trials=50):
    """Evaluate one round of pruning for several training fractions.

    For each of ``num_trials`` trials and each fraction in
    ``[0.3, 0.4, 0.5, 0.6, 0.7, 0.8]``:
    1. ``monkdata_set`` is split with ``partition``;
    2. a tree is built on the training part;
    3. the tree and every one-step pruning from ``dtree.allPruned`` are
       scored on the validation part.

    :param monkdata_set: dataset to split (from ``monkdata``).
    :param num_trials: number of repetitions per fraction.
    :returns: ``OrderedDict`` mapping each fraction to a list of
        ``(prune_index, accuracy)`` tuples, one per trial.
        ``accuracy`` is the best ``dtree.check`` score among the unpruned
        tree and its prunings. ``prune_index`` is the 1-based position of the
        first pruning that reached it, or 0 if no pruning beat the unpruned
        tree.
    """
    fractions = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    results = OrderedDict()

    for _ in range(num_trials):
        for fraction in fractions:
            training, validation = partition(monkdata_set, fraction)
            full_tree = dtree.buildTree(training, monkdata.attributes)
            best_accuracy = dtree.check(full_tree, validation)
            best_index = 0
            for index, candidate in enumerate(dtree.allPruned(full_tree), start=1):
                accuracy = dtree.check(candidate, validation)
                if accuracy > best_accuracy:
                    best_accuracy, best_index = accuracy, index
            results.setdefault(fraction, []).append((best_index, best_accuracy))

    return results
예제 #29
0
def prune(tree, prune_data):
    """Repeatedly adopt prunings that beat the current tree on ``prune_data``.

    Each pass scans the one-step prunings of the tree as it was at the start
    of the pass. It switches to any candidate scoring strictly higher than
    the tree currently held. Passes repeat until one makes no switch.
    """
    while True:
        changed = False
        for candidate in dtree.allPruned(tree):
            if dtree.check(tree, prune_data) < dtree.check(candidate, prune_data):
                tree = candidate
                changed = True
        if not changed:
            return tree
예제 #30
0
def best_pruned_tree(dataset, fraction):
    """Train on a random ``fraction`` of ``dataset`` and prune greedily on the rest.

    Each round moves to the first one-step pruning with the strictly highest
    ``dt.check`` score on the held-out part, if it beats the current tree.
    The search stops when a round finds no such pruning.
    """
    train, val = partition(dataset, fraction)
    current = dt.buildTree(train, m.attributes)
    while True:
        baseline = dt.check(current, val)
        best, best_score = current, baseline
        for candidate in dt.allPruned(current):
            candidate_score = dt.check(candidate, val)
            if candidate_score > best_score:
                best, best_score = candidate, candidate_score
        if not best_score > baseline:
            return current
        current = best
예제 #31
0
def complete_prune(tree, validation):
    """Greedily prune ``tree`` while some pruning strictly improves the validation score.

    Each pass scans the one-step prunings of the best tree at the start of the
    pass. It adopts any candidate scoring above the best so far. Passes
    repeat until one adopts nothing.
    """
    best_prune, best_perf = tree, dtree.check(tree, validation)
    searching = True
    while searching:
        searching = False
        for candidate in dtree.allPruned(best_prune):
            perf = dtree.check(candidate, validation)
            if perf <= best_perf:
                continue
            best_prune, best_perf, searching = candidate, perf, True
    return best_prune
예제 #32
0
파일: main.py 프로젝트: Axram/ML_Lab1
def prune(tree, testdata, performance_ref):
    """Recursively prune ``tree`` while the best pruning stays at or above a fixed reference.

    At each step, the highest-scoring one-step pruning (``dt.check`` on
    ``testdata``, the first one on ties) is adopted if its score is at least
    ``performance_ref``. The reference is never updated, so a step may lower
    the score as long as it stays at or above the reference.

    :returns: the last adopted tree, or ``tree`` when no pruning qualifies.
        A tree whose prunings all score 0 (or that has none) is now returned
        as-is. The original recursed on ``None`` there when
        ``performance_ref <= 0``.
    """
    best_tree = None
    best_per = 0
    for subtree in dt.allPruned(tree):
        performance = dt.check(subtree, testdata)
        if performance > best_per:
            best_per = performance
            best_tree = subtree
    if best_tree is not None and best_per >= performance_ref:
        return prune(best_tree, testdata, performance_ref)
    return tree
예제 #33
0
def best_pruned_tree(dataset, fraction):
    """Train on a random ``fraction`` of ``dataset`` and prune greedily on the rest.

    Each round scores all one-step prunings on the held-out part. It moves to
    the first highest-scoring one if that strictly beats the current tree,
    and stops otherwise.
    """
    train, val = partition(dataset, fraction)
    tree = dt.buildTree(train, m.attributes)
    while True:
        current_score = dt.check(tree, val)
        candidates = list(dt.allPruned(tree))
        scores = [dt.check(c, val) for c in candidates]
        if not scores or max(scores) <= current_score:
            return tree
        tree = candidates[scores.index(max(scores))]
예제 #34
0
def pruneRec(tree, monkval):
    """Recursively move to the best one-step pruning while it strictly improves.

    The first pruning with the highest ``dtree.check`` score on ``monkval``
    is followed if it beats ``tree``. Otherwise ``tree`` is returned.

    The original's nested pairwise loop kept only the winner of the last
    pair compared. It never compared against ``tree`` itself, and it stopped
    without pruning when exactly one pruning existed.
    """
    best, best_score = tree, dtree.check(tree, monkval)
    for candidate in dtree.allPruned(tree):
        score = dtree.check(candidate, monkval)
        if score > best_score:
            best, best_score = candidate, score
    if best is tree:
        return tree
    return pruneRec(best, monkval)
예제 #35
0
def findBestTree(tree, compare, lastBest=0, lastBestTree=None):
    """Follow strictly improving prunings of ``tree`` and return the best tree.

    Each round scores the one-step prunings (``d.allPruned``) of the latest
    winner on ``compare``. It keeps the highest-scoring one (the first one on
    ties) if it strictly beats the best score so far. The search stops when a
    round finds nothing better.

    :param lastBest: score a pruning must beat.
    :param lastBestTree: tree to return when nothing beats ``lastBest``.
        When omitted, ``tree`` itself is the fallback, and its own score (if
        higher) raises ``lastBest``. Previously the function returned
        ``None`` whenever no pruning scored above 0.
    """
    if lastBestTree is None:
        lastBestTree = tree
        lastBest = max(lastBest, d.check(tree, compare))

    bestTree, bestVal = lastBestTree, lastBest
    source = tree
    while True:
        improved = False
        for p in d.allPruned(source):
            val = d.check(p, compare)
            if val > bestVal:
                bestTree, bestVal, improved = p, val, True
        if not improved:
            return bestTree
        source = bestTree
예제 #36
0
def pruneTree(trainSet, fraction):
    """Apply the single best pruning to a tree trained on part of ``trainSet``.

    The tree built on the first part of ``partition(trainSet, fraction)``
    competes with its one-step prunings on the second part. The first
    strictly better candidate so far replaces it.

    :returns: ``(score, tree, validation_set)``.
    """
    monktrain, monkval = partition(trainSet, fraction)
    chosen = dtree.buildTree(monktrain, m.attributes)
    alternatives = dtree.allPruned(chosen)

    chosen_score = dtree.check(chosen, monkval)
    for alt in alternatives:
        alt_score = dtree.check(alt, monkval)
        if alt_score <= chosen_score:
            continue
        chosen, chosen_score = alt, alt_score
    return chosen_score, chosen, monkval
예제 #37
0
def findBestPrune(tree, validationdata):
    """Follow the best strictly improving one-step pruning until none remains.

    Each step compares the current tree with its ``d.allPruned`` candidates on
    ``validationdata``. It moves to the first candidate with the strictly
    highest score, and stops when the current tree is still the best.
    """
    current = tree
    while True:
        candidates = d.allPruned(current)
        best, best_score = current, d.check(current, validationdata)
        for candidate in candidates:
            score = d.check(candidate, validationdata)
            if score > best_score:
                best, best_score = candidate, score
        if best == current:
            return current
        current = best
예제 #38
0
파일: main.py 프로젝트: perket/DD2421
def assignment_7(monktrain, monkval):
    """Return the validation scores of a tree before and after greedy pruning.

    A tree built on ``monktrain`` is pruned by repeatedly adopting any
    one-step pruning that strictly beats the best score so far on
    ``monkval``.

    :returns: ``(initial_score, final_score)``.
    """
    current = dtree.buildTree(monktrain, m.attributes)
    initial = dtree.check(current, monkval)
    score = initial
    keep_going = True
    while keep_going:
        keep_going = False
        for candidate in dtree.allPruned(current):
            candidate_score = dtree.check(candidate, monkval)
            if candidate_score > score:
                current, score, keep_going = candidate, candidate_score, True
    return initial, dtree.check(current, monkval)
예제 #39
0
def find_best_pruned_tree(tree, validate):
    """Return the best of ``tree`` and its one-step prunings on ``validate``.

    :returns: ``(best_tree, best_score)``. The first pruning with the
        strictly highest score wins; otherwise ``tree`` is returned. The
        original recorded ``tree`` instead of the winning pruning, so the
        returned tree never matched the returned score.
    """
    best_tree = tree
    best_perf = d.check(tree, validate)
    for candidate in d.allPruned(tree):
        perf = d.check(candidate, validate)
        if perf > best_perf:
            best_tree, best_perf = candidate, perf
    return best_tree, best_perf
예제 #40
0
파일: lab.py 프로젝트: jacobhal/ML-DD2421
def prune(currentRatio, tree, validationSet):
    """Recursively climb to higher validation ratios and return the final one.

    The one-step pruning with the highest ``dtree.check`` ratio (the first one
    on ties) is followed while that ratio beats ``currentRatio``.

    :returns: the final ratio as a ``float``.
    """
    candidates = dtree.allPruned(tree)
    ratios = [dtree.check(candidate, validationSet) for candidate in candidates]
    best = max(ratios, default=currentRatio)
    if not currentRatio < best:
        return float(currentRatio)
    return prune(best, candidates[ratios.index(best)], validationSet)
def prune_tree(tree, validation_set):
    """Greedily prune ``tree``, accepting prunings that score at least as well.

    Each round switches to the last candidate whose ``dtree.check`` accuracy
    reaches the running best (starting from the current tree). The search
    stops when the current tree remains the choice.
    """
    cur_tree = tree
    while True:
        candidates = dtree.allPruned(cur_tree)
        chosen, chosen_acc = cur_tree, dtree.check(cur_tree, validation_set)
        for candidate in candidates:
            acc = dtree.check(candidate, validation_set)
            if acc >= chosen_acc:
                chosen, chosen_acc = candidate, acc
        if chosen == cur_tree:
            return cur_tree
        cur_tree = chosen
예제 #42
0
def prune(pruned_tree, test_tree):
    """Greedily prune while some pruning strictly beats the current tree.

    Despite its name, ``test_tree`` is the dataset passed to ``dt.check``.
    Each round moves to the first pruning with the strictly highest score,
    if it beats the current tree.

    :returns: the final tree.
    """
    current = pruned_tree
    while True:
        base_score = dt.check(current, test_tree)
        winner, winner_score = current, base_score
        for candidate in dt.allPruned(current):
            score = dt.check(candidate, test_tree)
            if score > winner_score:
                winner, winner_score = candidate, score
        if not winner_score > base_score:
            return winner
        current = winner
예제 #43
0
def prune(t, val):
    """Greedily prune ``t``, adopting any pruning that scores at least as well.

    Each pass scans the one-step prunings of the best tree at the start of the
    pass, so ties favour the pruned tree. Passes repeat until one adopts
    nothing.
    """
    best_tree, best_perf = t, d.check(t, val)
    while True:
        found = False
        for candidate in d.allPruned(best_tree):
            perf = d.check(candidate, val)
            if perf >= best_perf:
                best_tree, best_perf, found = candidate, perf, True
        if not found:
            return best_tree
예제 #44
0
def findBestPrune(tree, validationSet):
    """Greedily prune ``tree`` while the best pruning is not worse.

    Each round scores every pruning from ``dtree.allPruned`` on
    ``validationSet``. It moves to the best one (the last one on ties)
    unless it scores below the current tree. The search stops when no
    prunings remain.

    Uses ``enumerate`` and the builtin ``max`` instead of
    ``itertools.izip``/``count``, so it runs on both Python 2 and 3.
    (``izip`` does not exist on Python 3.) An empty list of any sequence
    type now stops the search; the old ``== ()`` check only matched tuples.
    """
    current = tree
    while True:
        currentPerformance = dtree.check(current, validationSet)
        pruned = dtree.allPruned(current)
        if not pruned:
            break
        # (score, index) tuples: the highest score wins and, among ties, the
        # highest index, the same choice as the original izip/count form.
        # Open question from the original: how should equal scores be broken
        # (depth, node count, order in allPruned)?
        best, i = max((dtree.check(t, validationSet), k) for k, t in enumerate(pruned))
        if best < currentPerformance:
            break
        current = pruned[i]
    return current
예제 #45
0
파일: lab1.py 프로젝트: fristedt/maskin
def assignment4helper(dataset, fraction):
    """Train on a random split of ``dataset`` and greedily prune on the rest.

    The tree built on the first part of ``partition(dataset, fraction)``
    competes with its prunings. A pruning is adopted only if it strictly
    beats the best validation score so far.

    Fixes over the original: it started from a score of -1, so it always
    adopted at least one pruning even when that pruning was worse. It also
    returned ``None`` for a tree with no prunings. The unused counter ``i``
    is removed.
    """
    monk1train, monk1val = partition(dataset, fraction)
    tree = d.buildTree(monk1train, m.attributes)
    maxVal = d.check(tree, monk1val)
    while True:
        improved = False
        for t in d.allPruned(tree):
            val = d.check(t, monk1val)
            if val > maxVal:
                improved = True
                tree = t
                maxVal = val
        if not improved:
            return tree
예제 #46
0
def pruning(trainingSet, testSet, fraction):
    """Return the ``testSet`` score of a greedily pruned tree.

    ``trainingSet`` is split with ``partition``. The first part builds the
    tree and the second part guides pruning. Any pruning that scores at
    least as well as the current tree is accepted, so ties favour pruning.
    """
    build_part, prune_part = partition(trainingSet, fraction)
    current = dT.buildTree(build_part, m.attributes)
    current_score = dT.check(current, prune_part)

    changed = True
    while changed:
        changed = False
        for candidate in dT.allPruned(current):
            score = dT.check(candidate, prune_part)
            if score >= current_score:
                current, current_score, changed = candidate, score, True

    return dT.check(current, testSet)
예제 #47
0
def findPrunned(t, monk1val1):
    """Prune ``t`` while some pruning strictly improves, then report the final score.

    Moves to the best-scoring one-step pruning (``d.check`` on ``monk1val1``)
    while it strictly beats the current tree. When nothing is better, the
    score is appended to the module-level ``efficiency`` list, printed, and
    returned.

    Fixes over the original:
    * the recursive call's result was discarded, so callers got ``None``
      whenever any pruning happened;
    * ``answertree`` could be unbound;
    * the comparison chose the *lowest* score although the variables
      (``maxi1``) track a maximum.

    NOTE(review): this assumes ``d.check`` is higher-is-better, as elsewhere
    (error is computed as ``1 - check``); confirm.
    ``print(x)`` with one argument prints identically on Python 2 and 3.
    """
    current = t
    while True:
        current_score = d.check(current, monk1val1)
        best_tree, best_score = current, current_score
        for s in d.allPruned(current):
            val = d.check(s, monk1val1)
            if val > best_score:
                best_tree, best_score = s, val
        if best_tree is current:
            efficiency.append(current_score)
            print(current_score)
            return current_score
        current = best_tree
예제 #48
0
def unzip(values):
    """Transpose a sequence of rows into a list of column lists.

    ``unzip([(1, 'a'), (2, 'b')])`` returns ``[[1, 2], ['a', 'b']]``. Like
    ``zip``, columns are truncated to the shortest row.
    """
    columns = zip(*values)
    return list(map(list, columns))

# Pruning experiment. For each entry in ``setpairs`` (defined elsewhere) and
# each training fraction:
#   1. split ``pair[0]`` with ``partition``;
#   2. build a tree on the training part;
#   3. prune it greedily against the validation part;
#   4. record ``dtree.check`` of the result on ``pair[1]``.
# ``series`` gets one ``[fractions, scores, pair[2]]`` entry per pair, for
# ``printlines``.
fractions = [0.3,0.4,0.5,0.6,0.7,0.8]
series=[]
for pair in setpairs:
    values = []
    for fraction in fractions:
        s = pair[0]  # dataset that gets split
        testdata = pair[1]  # set the final tree is scored on
        training, validation = partition(s, fraction)
        tree=dtree.buildTree(training, monkdata.attributes)
        keepPruning = True
        # Greedy pruning: adopt any pruning that strictly beats the current
        # tree on the validation split; repeat until a pass adopts none.
        while keepPruning:
            alternatives = dtree.allPruned(tree)
            keepPruning = False
            for alternative in alternatives:
                if(dtree.check(alternative,validation) > dtree.check(tree,validation)):
                    tree = alternative
                    keepPruning = True
        # NOTE(review): this stores the raw dtree.check score, not
        # 1 - score, despite the name ``error``; confirm what the plot expects.
        error=dtree.check(tree,testdata)
        values.append((fraction,error))
    #convert pairs to two lists [xs, ys]
    data=unzip(values)
    # pair[2] is appended after the two columns (presumably a series label; verify).
    data.append(pair[2])
    series.append(data)

print("Pruned trees:")
printlines(series)
print("")
예제 #49
0
"--Assignment 4"
# For MONK-1 and MONK-3 and each training fraction:
#   1. build a tree on a random split;
#   2. prune it greedily on the validation part;
#   3. record the best validation score, rounded to 5 places.
# NOTE(review): despite the "Errors" names, the stored values are
# ``t.check(..., m1val)`` validation scores, not errors on a test set;
# confirm this is what is intended.
fractionErrors = []
fractions = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

for monk in [m.monk1, m.monk3]:
    tempErrors = []  # per-fraction results for this dataset, reset for each dataset

    for f in fractions:
        m1train, m1val = partition(monk, f)  # create new partitioned datasets
        tree = t.buildTree(m1train, m.attributes)  # create tree with the new datasets

        tempPerformance = t.check(tree, m1val)  # get current performance of validation set
        bestPerformance = 0  # best performance w chosen monk and fraction, set as 0

        while bestPerformance < tempPerformance:  # continue until no pruning beats the current tree
            bestPerformance = tempPerformance  # while loop taken, tempPerformance is best so far

            # Adopt any pruning strictly better than the best seen so far.
            for pTree in t.allPruned(tree):
                prunePerformance = t.check(pTree, m1val)

                if tempPerformance < prunePerformance:
                    tempPerformance = prunePerformance
                    tree = pTree

        tempErrors.append(round(bestPerformance, 5))
    fractionErrors.append(tempErrors)

print(fractionErrors)
print("--------------------------------------")
예제 #50
0
	return ldata[:p], ldata[p:]

# Python 2 script: for each split fraction, prune a MONK-1 tree greedily and
# print the resulting tree, how many pruning rounds improved it, and its score.
splits = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

print "-----  Mock1  -----"
for split in splits:
  # ``part`` returns (ldata[:p], ldata[p:]); p is presumably derived from ``split``.
  monk1train, monk1val = part(m.monk1, split)

  # NOTE(review): the tree is built on ``monk1val`` (second part) and pruned
  # against ``monk1train`` (first part); this looks swapped relative to the
  # names. Confirm before comparing with other results.
  bestTree = dt.buildTree(monk1val,m.attributes)
  bestClassification = dt.check(bestTree,monk1train)
  foundBetter = True
  numberOfPrunes = 0  # counts rounds that adopted at least one pruning

  # Greedy pruning: adopt any pruning strictly better than the best so far;
  # repeat until a round adopts none.
  while foundBetter:
    foundBetter = False
    for subTree in dt.allPruned(bestTree):
      if dt.check(subTree,monk1train) > bestClassification:
        bestClassification = dt.check(subTree,monk1train)
        bestTree = subTree
        foundBetter = True
    if foundBetter:
      numberOfPrunes += 1

  print "Best tree found with split = ", split, ", pruned ", numberOfPrunes, " times"
  print bestTree
  print "%.5f"%bestClassification

print


print "-----  Mock3  -----"