/
logisticregression.py
58 lines (40 loc) · 1.89 KB
/
logisticregression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from classification import ClassificationModel
from argumentparser import *
class LogisticRegression(ClassificationModel):
    """Logistic-regression classifier plugged into the ClassificationModel pipeline.

    All run options (solver, printing, cross-validation folds, and whatever the
    ``ClassificationModel.preprocessData*`` helpers read) come from the
    argument namespace passed to the constructor.
    """

    def __init__(self, _args):
        """Store the parsed command-line arguments.

        Args:
            _args: argument namespace; this class reads ``solver``,
                ``print_accuracy`` and ``k_fold_cross_validation`` from it and
                forwards it to the ClassificationModel preprocessing helpers.
        """
        # NOTE(review): ClassificationModel.__init__ is not called here (as in
        # the original); confirm the base class needs no initialisation.
        self.args = _args

    @staticmethod
    def _buildModel(_solver):
        """Return an *unfitted* sklearn LogisticRegression using ``_solver``."""
        # Aliased so the sklearn class does not shadow this class's name.
        from sklearn.linear_model import LogisticRegression as SkLogisticRegression
        return SkLogisticRegression(solver=_solver)

    @staticmethod
    def computeModel(XTrain, yTrain, _solver):
        """Fit a logistic-regression classifier.

        Args:
            XTrain: training feature matrix.
            yTrain: training labels.
            _solver: sklearn solver name (e.g. ``"lbfgs"``).

        Returns:
            The fitted sklearn ``LogisticRegression`` estimator.
        """
        classifier = LogisticRegression._buildModel(_solver)
        classifier.fit(XTrain, yTrain)
        return classifier

    def compute(self):
        """Train on a train/test split and evaluate on the test part.

        Returns:
            Tuple ``(confusionMatrix, rocCurve, accuracy, elapsedSeconds,
            classifier)``. ``elapsedSeconds`` covers preprocessing, training
            and evaluation, but not the optional console print.
        """
        import timeit
        start = timeit.default_timer()
        XTrain, XTest, yTrain, yTest = ClassificationModel.preprocessData(self.args, True)
        classifier = LogisticRegression.computeModel(XTrain, yTrain, self.args.solver)
        yPred = ClassificationModel.predictModel(classifier, XTest)
        confusionMatrix = ClassificationModel.getConfusionMatrix(yPred, yTest)
        # NOTE(review): the ROC curve is built from predictModel's output; if
        # those are hard labels rather than scores, the curve is degenerate.
        rocCurve = ClassificationModel.getRocCurve(yPred, yTest)
        # Compute accuracy once and reuse it for printing and the return value.
        accuracy = ClassificationModel.getAccuracy(confusionMatrix)
        # Stop the timer before any console output so I/O is not measured.
        stop = timeit.default_timer()
        if self.args.print_accuracy:
            print(confusionMatrix, accuracy)
        return confusionMatrix, rocCurve, accuracy, stop - start, classifier

    def computeCrossValidation(self):
        """Run k-fold cross-validation with ``args.k_fold_cross_validation`` folds.

        Returns:
            The dict returned by ``sklearn.model_selection.cross_validate``.
        """
        from sklearn.model_selection import cross_validate
        X, y = ClassificationModel.preprocessDataCrossValidation(self.args, True)
        # cross_validate clones and refits the estimator on every fold, so an
        # unfitted estimator is enough; fitting on all of X first was wasted work.
        classifier = LogisticRegression._buildModel(self.args.solver)
        cv_results = cross_validate(classifier, X, y, cv=self.args.k_fold_cross_validation)
        if self.args.print_accuracy:
            print(cv_results)
        return cv_results
if __name__ == "__main__":
    # Assemble the CLI: shared options first, then the logistic-regression ones.
    parser = ArgumentParser()
    parser.setBasicArguments()
    parser.setLogisticRegressionArguments()
    args = parser.getArguments()

    model = LogisticRegression(args)
    # An explicit False runs a single train/test evaluation; any other value
    # selects k-fold cross-validation.
    run = model.compute if args.cross_validation == False else model.computeCrossValidation
    run()