-
Notifications
You must be signed in to change notification settings - Fork 1
/
Draw_subplots.py
121 lines (96 loc) · 3.51 KB
/
Draw_subplots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import pandas as pd
import matplotlib.pylab as plt
def _plot_comparison(ax, truth, pred, title, title_size, show_legend=False):
    """Draw one truth-vs-prediction panel on *ax* with a log-scaled y axis.

    Args:
        ax: Matplotlib Axes to draw on.
        truth: Sequence of true sales values (red line).
        pred: Sequence of predicted sales values (blue line); plotted against
            the same day indices as *truth*.
        title: Panel title.
        title_size: Font size for the title.
        show_legend: Whether to draw the legend on this panel.
    """
    days = range(len(truth))
    ax.plot(days, truth, color='r', label='truth sales')
    ax.plot(days, pred, color='b', label='pred sales')
    ax.set_yscale('log')
    ax.set_xlabel("days")
    ax.set_ylabel("sales")
    if show_legend:
        ax.legend(loc='upper left', frameon=False)
    ax.set_title(title, fontsize=title_size)


def drawSubplots(output_path="store 285 subplots.png", *, show=True):
    """Plot true vs. predicted sales for four models in a 2x2 grid.

    Each model's results are loaded with ``read_csv`` from its
    ``<Model>_results.csv`` file. The first returned column is plotted as the
    true series and the second as the prediction, on a log-scaled y axis. The
    figure is saved to *output_path* and then optionally displayed.

    Args:
        output_path: File the PNG figure is written to. Defaults to the
            previously hard-coded ``"store 285 subplots.png"``.
        show: Keyword-only. If true, call ``plt.show()`` after saving;
            otherwise close the figure so it does not linger in pyplot's
            global figure registry.
    """
    # (results CSV, panel title) in row-major order of the 2x2 grid.
    panels = (
        ("LinearRegression_results.csv", "Linear Regression"),
        ("RandomForestRegressor_results.csv", "Random Forest Regressor"),
        ("ARIMA_results.csv", "ARIMA"),
        ("RNN_results.csv", "LSTM"),
    )
    title_size = 10

    # Load everything first so a missing/bad file fails before any figure
    # is created.
    results = [read_csv(filename) for filename, _ in panels]

    # Exactly one figure: the original created an extra empty figure first,
    # which plt.show() would also pop up.
    fig, axes = plt.subplots(2, 2, dpi=100, figsize=(20, 7))
    for index, (ax, (_, title), (truth, pred)) in enumerate(
        zip(axes.flat, panels, results)
    ):
        # Only the first panel carries a legend; the colours are shared.
        _plot_comparison(ax, truth, pred, title, title_size,
                         show_legend=(index == 0))

    # Save before show(): once the window is closed the figure is gone.
    fig.savefig(output_path, format='png', bbox_inches='tight',
                transparent=False)
    if show:
        plt.show()
    else:
        plt.close(fig)
def read_csv(filename):
    """Return the first two columns of the CSV at *filename* as a pair.

    The file is parsed with ``pandas.read_csv`` using its default settings.
    Columns are selected by position, so their header names do not matter.

    Args:
        filename: Anything ``pandas.read_csv`` accepts (path or buffer).

    Returns:
        tuple: ``(first_column, second_column)`` as pandas Series. Callers in
        this module unpack them as (true values, predicted values).
    """
    frame = pd.read_csv(filename)
    first_col, second_col = (frame.iloc[:, position] for position in (0, 1))
    return first_col, second_col
if __name__ == "__main__":
    # Script entry point: build, save and display the model-comparison figure.
    drawSubplots()