forked from jan-timpe/Tweet-sentiment-analysis
/
spark.py
50 lines (40 loc) · 1.26 KB
/
spark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import classifier
import numpy as np
from pyspark import SparkContext
from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel, LabeledPoint
from pyspark.mllib.util import MLUtils
import shutil
def preprocess(sc, data, labels=None):
    """Convert raw samples into an RDD of LabeledPoint for MLlib.

    Runs the samples through classifier.tolibsvm() to get feature
    vectors, pairs each vector with its label (0 when no labels are
    supplied, e.g. at prediction time), and parallelizes the result.

    :param sc: active SparkContext used to build the RDD
    :param data: raw input samples, converted by classifier.tolibsvm
    :param labels: optional sequence of labels, parallel to ``data``
    :return: RDD of pyspark.mllib.regression.LabeledPoint
    """
    vectors = classifier.tolibsvm(data)
    if labels is None:
        # Prediction path: no ground truth, use a dummy label of 0.
        # NOTE: checking `is None` (not truthiness) so an explicitly
        # passed empty/all-zero label sequence is not silently ignored.
        points = [LabeledPoint(0, vec) for vec in vectors]
    else:
        points = [LabeledPoint(label, vec)
                  for label, vec in zip(labels, vectors)]
    return sc.parallelize(points)
def traintestsplit(data):
    """Randomly split an RDD 60/40 into train and test subsets.

    :param data: input RDD of labeled points
    :return: list of two RDDs, [train (60%), test (40%)]
    """
    split_weights = [0.6, 0.4]
    return data.randomSplit(split_weights)
def context(appname):
    """Create and return a SparkContext for the given application name.

    :param appname: name shown in the Spark UI for this application
    :return: a new SparkContext
    """
    return SparkContext(appName=appname)
def train(data):
    """Train a Naive Bayes classifier on labeled training data.

    :param data: RDD of LabeledPoint training examples
    :return: a fitted NaiveBayesModel (smoothing parameter 1.0)
    """
    return NaiveBayes.train(data, 1.0)
def test(model, data):
    """Measure model accuracy on a labeled evaluation RDD.

    :param model: trained NaiveBayesModel
    :param data: RDD of LabeledPoint examples with ground-truth labels
    :return: tuple of (accuracy as a float in [0, 1], the model)
    """
    paired = data.map(lambda point: (model.predict(point.features), point.label))
    correct = paired.filter(lambda pair: pair[0] == pair[1]).count()
    accuracy = float(correct) / data.count()
    return accuracy, model
def save(model, sc, filename):
    """Persist a trained model to disk at the given path.

    :param model: trained NaiveBayesModel to persist
    :param sc: active SparkContext (required by model.save)
    :param filename: target directory path for the saved model
    """
    # MLlib's save() fails if the target path already exists, so remove
    # any stale model directory first; ignore_errors covers a missing path.
    shutil.rmtree(filename, ignore_errors=True)
    model.save(sc, filename)
def load(sc, filename):
    """Load a previously saved Naive Bayes model from disk.

    :param sc: active SparkContext (required by NaiveBayesModel.load)
    :param filename: directory path where the model was saved
    :return: tuple of (the same SparkContext, the loaded model)
    """
    loaded = NaiveBayesModel.load(sc, filename)
    return sc, loaded
def predict(sc, model, data):
    """Predict labels for raw (unlabeled) input samples.

    Preprocesses the samples into feature vectors, runs the model on
    each, and returns the predictions.

    :param sc: active SparkContext
    :param model: trained NaiveBayesModel
    :param data: raw input samples accepted by preprocess()
    :return: list of predicted labels, one per input sample
    """
    rdd = preprocess(sc, data)
    pred = rdd.map(lambda p: model.predict(p.features))
    # BUG FIX: the original called pred.show(), but .show() is a DataFrame
    # method — RDDs have no such attribute, so this raised AttributeError
    # and the predictions were discarded. Collect and return them instead.
    return pred.collect()