-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
35 lines (26 loc) · 1012 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from sklearn.svm import LinearSVC
from sklearn.externals import joblib
from HOG import HOG
import dataset
import argparse
# Set up the argument parser, which takes the path to the CSV dataset file
# and the location where the trained model is to be stored.
argparser = argparse.ArgumentParser()
argparser.add_argument("-d", "--dataset", required=True,
                       help="path to the dataset file")
argparser.add_argument("-m", "--model", required=True,
                       help="path to where the model will be stored")
args = vars(argparser.parse_args())

# Load the digit samples and their labels from the dataset file.
(digits, labels) = dataset.load_data(args["dataset"])

# HOG descriptor used to turn each digit image into a feature vector.
hog = HOG(orientations=18, pixelsPerCell=(10, 10), cellsPerBlock=(1, 1),
          normalise=True)

# Build one HOG histogram per digit. Each sample is deskewed first, then
# reshaped to a 28x28 image before being described.
# NOTE(review): this presumes every sample holds 28*28 = 784 pixels
# (e.g. a flat row from the CSV) -- TODO confirm against dataset.load_data.
data = [hog.describe(dataset.deskew(digit).reshape((28, 28)))
        for digit in digits]

# Set up and train a linear SVM on the HOG features.
SVC_model = LinearSVC()
SVC_model.fit(data, labels)

# Save the trained model to file. compress=3 selects a moderate joblib
# compression level, giving a smaller file at some extra CPU cost.
# NOTE(review): `sklearn.externals.joblib` (imported at the top of this file)
# was deprecated in scikit-learn 0.21 and removed in 0.23. On current
# scikit-learn versions, replace that import with the standalone
# `import joblib`.
joblib.dump(SVC_model, args["model"], compress=3)