forked from pmiller10/stock_price
-
Notifications
You must be signed in to change notification settings - Fork 0
/
submitter.py
28 lines (24 loc) · 810 Bytes
/
submitter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import sys
from data.stock import Stock
from predictor import Predictor
from score import auc
from preprocess import Preprocess
# The submission number is taken from the first command-line argument; it is
# substituted into the output file name by submission().
if len(sys.argv) < 2:
    # Exit with a usage hint (status 1) instead of a bare IndexError traceback.
    sys.exit("usage: submitter.py <submission_number>")
submission_number = sys.argv[1]
# Parenthesised single-argument form prints identically on Python 2 and 3.
print(submission_number)
def submission(ids, preds):
    """Write ids and predictions to the numbered submission CSV file.

    The output path is ``submissions/submission<N>.csv``, where ``<N>`` is the
    module-level ``submission_number``. The file consists of an
    ``id,prediction`` header line followed by one ``<id>,<prediction>`` line
    per pair, each value rendered with ``str()``.

    Args:
        ids: Sized iterable of row identifiers.
        preds: Sized iterable of predictions, paired positionally with ``ids``.

    Raises:
        ValueError: If ``ids`` and ``preds`` differ in length. The check runs
            before the file is opened, so a mismatch neither creates nor
            truncates the output file.
    """
    # Validate first: opening with 'w' truncates, so a failed call must not
    # reach open() and leave an empty file or a dangling handle behind.
    if len(ids) != len(preds):
        raise ValueError("The number of IDs and the number of predictions are different")

    name = "submissions/submission{0}.csv".format(submission_number)

    # Collect the lines and join once rather than growing a string with +=.
    lines = ['id,prediction\n']
    lines.extend(str(i) + ',' + str(pred) + "\n" for i, pred in zip(ids, preds))

    # The context manager closes the file even if the write fails.
    with open(name, 'w') as f:
        f.write(''.join(lines))
# Load the training set; the third value returned by Stock.train() is unused.
data, targets, _ = Stock.train()
# Load the holdout set and the ids that are later written next to each prediction.
holdout_data, ids = Stock.test()
# Sanity check on the loader output before fitting.
assert len(data) == len(targets), "training data and targets differ in length"
Predictor.train(data, targets)
preds = Predictor.multi_predict(holdout_data)
# Show the first 50 predictions for a quick visual check; the parenthesised
# single-argument form prints identically on Python 2 and 3.
print(preds[0:50])
submission(ids, preds)