/
main.py
50 lines (37 loc) · 1.46 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import RandomForest
import DecisionTree
import Evaluation
import Utils
import os
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
from pyspark import SparkConf, SparkContext, mllib
from pyspark.mllib.util import MLUtils
def run(searchForOptimal, basepath, filepath):
    """Train and evaluate classification models on a LibSVM dataset.

    Builds a SparkContext, loads and splits the data, then either searches
    for optimal hyperparameters (searchForOptimal=True) or trains with
    defaults. Each model is evaluated on the held-out test split.

    :param searchForOptimal: when True, use the trainOptimalModel search
        (which also needs the test split); otherwise train with defaults.
    :param basepath: directory containing the data file.
    :param filepath: data file name, joined onto basepath.
    """
    sc = buildContext()
    # Ensure the context is released even if training/evaluation fails;
    # the original leaked the SparkContext on any exception.
    try:
        trainingData, testData = loadData(sc, basepath, filepath)
        if searchForOptimal:
            optimalRandomForestModel = RandomForest.trainOptimalModel(trainingData, testData)
            Evaluation.evaluate(optimalRandomForestModel, testData, logMessage=True)
            optimalDecisionTreeModel = DecisionTree.trainOptimalModel(trainingData, testData)
            Evaluation.evaluate(optimalDecisionTreeModel, testData, logMessage=True)
        else:
            randomForestModel = RandomForest.trainModel(trainingData)
            Evaluation.evaluate(randomForestModel, testData, logMessage=True)
            decisionTreeModel = DecisionTree.trainModel(trainingData)
            Evaluation.evaluate(decisionTreeModel, testData, logMessage=True)
    finally:
        sc.stop()
def buildContext():
conf = SparkConf().setAppName('ClassificationModel')
print "\nBuild context finished"
return SparkContext(conf = conf)
def loadData(sc, basepath, filepath):
data = MLUtils.loadLibSVMFile(sc, os.path.join(basepath, filepath))
trainingData, testData = data.randomSplit([0.7,0.3])
print '\nLoad data finished'
return trainingData, testData
if __name__ == '__main__':
    # Entry point. Defaults match the original hard-coded values; optional
    # command-line overrides: main.py [basepath [filepath]].
    basepath = sys.argv[1] if len(sys.argv) > 1 else '/home/yifei/TestData/data'
    filepath = sys.argv[2] if len(sys.argv) > 2 else 'a9a_data.txt'
    searchForOptimal = True
    run(searchForOptimal, basepath, filepath)