-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
executable file
·63 lines (47 loc) · 1.98 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/python3.5
import csv
from sklearn.linear_model import LogisticRegression
from Mail import Mail
# Shared module-level mail collections. They start out empty and are filled
# in place: init_training_set() appends to the two training lists and
# init_test_set() appends to test_set. Each name is bound to its own list.
junk_training_set, desired_training_set, test_set = [], [], []
def init_training_set():
    """Load the labelled training mails into the module-level training lists.

    Reads ``res/desired/index.csv`` first and ``res/junk/index.csv`` second
    (both relative to the current working directory). Every CSV row is
    wrapped in a ``Mail`` -- with ``False`` as second argument for desired
    rows and ``True`` for junk rows -- and appended to
    ``desired_training_set`` or ``junk_training_set`` respectively.

    Raises:
        OSError: if either index file cannot be opened.
    """
    # (index file, second Mail argument, list that receives the mails)
    sources = (
        ('res/desired/index.csv', False, desired_training_set),
        ('res/junk/index.csv', True, junk_training_set),
    )
    for path, junk_flag, target in sources:
        # newline='' is what the csv module asks for: the reader then does its
        # own line-ending handling, so quoted fields that contain line breaks
        # are parsed intact on every platform.
        with open(path, newline='') as csvfile:
            for row in csv.reader(csvfile, delimiter=',', quotechar='"'):
                target.append(Mail(row, junk_flag))
def init_logistic_regression():
    """Fit a logistic-regression classifier on the training mails.

    Every mail contributes one sample with two features,
    ``[word_count, sign_count]``. Samples from ``junk_training_set`` come
    first, followed by those from ``desired_training_set``.

    Class labels are ``1`` for junk and ``0`` for desired mail.

    Returns:
        The fitted ``LogisticRegression`` instance.
    """
    features = [
        [mail.word_count, mail.sign_count]
        for mail in junk_training_set + desired_training_set
    ]
    # Junk must be the positive class: check_prediction_result() compares the
    # predicted label with mail.is_junk, and junk mails are created with
    # Mail(row, True). Labelling junk as 0 (as before) made every correct
    # prediction count as a miss, because True == 0 is False.
    # NOTE(review): assumes Mail stores its second constructor argument as
    # ``is_junk`` -- confirm against the Mail class.
    labels = [1] * len(junk_training_set) + [0] * len(desired_training_set)

    model = LogisticRegression()
    model.fit(features, labels)
    return model
def predict_test_mails():
    """Classify every mail in ``test_set`` and store the label on the mail.

    A classifier is obtained from ``init_logistic_regression()``; the class
    label it predicts for each mail is written to ``mail.is_junk_predicted``
    as a single scalar value.
    """
    lr = init_logistic_regression()
    if not test_set:
        # Nothing to classify, and predict() cannot be called with no samples.
        return

    # predict() expects a 2D (n_samples, n_features) matrix. Passing the bare
    # [word_count, sign_count] list of one mail is rejected by current
    # scikit-learn, so build the whole matrix and predict once.
    features = [[mail.word_count, mail.sign_count] for mail in test_set]
    for mail, label in zip(test_set, lr.predict(features)):
        # Store the scalar label rather than a one-element array.
        mail.is_junk_predicted = label
def check_prediction_result(mails=None):
    """Print the percentage of mails whose prediction matches their label.

    A mail counts as correctly predicted when ``mail.is_junk`` equals
    ``mail.is_junk_predicted``. The result is printed as a percentage
    rounded to two decimal places.

    Args:
        mails: sequence of mails to evaluate; each needs ``is_junk`` and
            ``is_junk_predicted`` attributes. Defaults to the module-level
            ``test_set``.
    """
    if mails is None:
        mails = test_set
    if not mails:
        # Avoid ZeroDivisionError when there is nothing to evaluate.
        # The message is Polish, like the rest of the program's output;
        # it means "The test set is empty."
        print("Zbiór testowy jest pusty.")
        return

    hits = sum(1 for mail in mails if mail.is_junk == mail.is_junk_predicted)
    # Scale to percent *before* rounding: rounding the ratio first and then
    # multiplying by 100 re-introduces float noise (e.g. 7.000000000000001).
    accuracy_percent = round(100 * hits / len(mails), 2)
    print("Skuteczność wynosi: ", accuracy_percent, "%")
def init_test_set():
    """Load the test mails into the module-level ``test_set`` list.

    Rows of ``test/desired/index.csv`` are appended first, then rows of
    ``test/junk/index.csv`` (both paths relative to the current working
    directory). Junk rows are wrapped as ``Mail(row, True)``; desired rows
    as ``Mail(row)``.

    Raises:
        OSError: if either index file cannot be opened.
    """
    # newline='' is what the csv module asks for: the reader then does its
    # own line-ending handling, so quoted fields that contain line breaks
    # are parsed intact on every platform.
    with open('test/desired/index.csv', newline='') as csvfile:
        for row in csv.reader(csvfile, delimiter=',', quotechar='"'):
            # NOTE(review): relies on Mail's default for its second argument
            # (presumably "not junk", matching Mail(row, False) used for the
            # training data) -- confirm against the Mail class.
            test_set.append(Mail(row))

    with open('test/junk/index.csv', newline='') as csvfile:
        for row in csv.reader(csvfile, delimiter=',', quotechar='"'):
            test_set.append(Mail(row, True))
# Script entry sequence. These calls run whenever the module is executed or
# imported (there is no __main__ guard), and their order matters:
# 1. fill junk_training_set / desired_training_set from the res/ index files,
# 2. fill test_set from the test/ index files,
# 3. fit the classifier on the training lists and label every test mail,
# 4. print how many predictions match the mails' real labels.
init_training_set()
init_test_set()
predict_test_mails()
check_prediction_result()