-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_DNN.py
128 lines (106 loc) · 4.15 KB
/
train_DNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
import sys
from keras import Sequential
from keras.layers import LSTM, Dense, Dropout, Conv2D, Flatten, \
BatchNormalization, Activation, MaxPooling2D
from keras.utils import np_utils
from tqdm import tqdm
from keras.models import load_model
from utilities import get_data, class_labels
import numpy as np
import scipy.io.wavfile as wav
import os
import speechpy
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
# Model names selectable from the command line (model_number - 1 indexes this list).
models = [
    "CNN",
    "LSTM",
]
def get_model(model_name, input_shape):
    """
    Build and compile the requested Keras model.

    :param model_name: ``'CNN'`` or ``'LSTM'``.
    :param input_shape: shape of a single training sample. Only
        ``input_shape[0]`` and ``input_shape[1]`` are used: the CNN input is
        ``(input_shape[0], input_shape[1], 1)`` and the LSTM input is
        ``(input_shape[0], input_shape[1])``.
    :return: compiled ``Sequential`` model whose last layer is a softmax over
        ``len(class_labels)`` classes.
    :raises ValueError: if ``model_name`` is not ``'CNN'`` or ``'LSTM'``.
    """
    # Models are inspired from
    # CNN - https://yashk2810.github.io/Applying-Convolutional-Neural-Network-on-the-MNIST-dataset/
    # LSTM - https://github.com/harry-7/Deep-Sentiment-Analysis/blob/master/code/generatePureLSTM.py
    if model_name not in ('CNN', 'LSTM'):
        # Without this guard an unknown name would silently produce a model
        # consisting only of the softmax head, with no declared input shape.
        raise ValueError('Unknown model name: {}'.format(model_name))

    model = Sequential()
    if model_name == 'CNN':
        print("model is CNN")
        model.add(Conv2D(8, (13, 13), input_shape=(input_shape[0], input_shape[1], 1)))
        model.add(BatchNormalization(axis=-1))
        model.add(Activation('relu'))
        model.add(Conv2D(8, (13, 13)))
        model.add(BatchNormalization(axis=-1))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 1)))
        model.add(Conv2D(8, (13, 13)))
        model.add(BatchNormalization(axis=-1))
        model.add(Activation('relu'))
        model.add(Conv2D(8, (2, 2)))
        model.add(BatchNormalization(axis=-1))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 1)))
        model.add(Flatten())
        model.add(Dense(64))
        model.add(BatchNormalization())
        model.add(Activation('relu'))
        model.add(Dropout(0.2))
    else:
        print("model is LSTM")
        model.add(LSTM(128, input_shape=(input_shape[0], input_shape[1])))
        model.add(Dropout(0.3))
        # NOTE(review): the source's indentation was lost; these two Dense
        # layers are assumed to belong to the LSTM branch only — confirm.
        model.add(Dense(32, activation='relu'))
        model.add(Dense(16, activation='tanh'))

    # Shared classification head: one softmax unit per emotion class.
    model.add(Dense(len(class_labels), activation='softmax'))
    # categorical_crossentropy matches the softmax output and the one-hot
    # labels. With binary_crossentropy, Keras treats every output unit as an
    # independent binary problem and reports binary accuracy, which greatly
    # overstates multi-class accuracy.
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
def evaluateModel(model, iterations=50, epochs_per_iteration=10, batch_size=32):
    """
    Train ``model`` and keep the weights that score best on the test data.

    Reads the module-level ``x_train``, ``y_train``, ``x_test``, ``y_test``
    and ``best_model_path`` (set by the ``__main__`` block).

    Training runs ``iterations`` rounds. Each round reshuffles the training
    data and calls ``model.fit`` for ``epochs_per_iteration`` epochs. The
    model is then evaluated on the test data, and its weights are saved to
    ``best_model_path`` whenever the accuracy improves. The first round
    always saves. At the end the best weights are reloaded and the final
    test accuracy is printed.

    NOTE(review): the test set is used for model selection here, so the
    printed accuracy is optimistically biased; a separate validation split
    would give an unbiased estimate.

    :param model: compiled Keras model to be trained and evaluated
    :param iterations: number of train/evaluate rounds (default 50)
    :param epochs_per_iteration: epochs per ``model.fit`` call (default 10)
    :param batch_size: mini-batch size for ``model.fit`` (default 32)
    """
    # None (rather than 0) guarantees the first round saves weights, so the
    # final load_weights never reads a missing or stale file from a prior run.
    best_acc = None
    for _ in tqdm(range(iterations)):
        # Shuffle inputs and labels in unison, inspired from
        # https://stackoverflow.com/a/4602224 . A local permutation is used so
        # the module-level arrays are not rebound.
        perm = np.random.permutation(len(x_train))
        model.fit(x_train[perm], y_train[perm], batch_size=batch_size,
                  epochs=epochs_per_iteration)
        _, acc = model.evaluate(x_test, y_test)
        if best_acc is None or acc > best_acc:
            print('Updated best accuracy', acc)
            best_acc = acc
            model.save_weights(best_model_path)
    if best_acc is not None:
        # Restore the best checkpoint before the final report.
        model.load_weights(best_model_path)
    print('Accuracy = ', model.evaluate(x_test, y_test)[1])
if __name__ == "__main__":
    # Parse and validate the model number before loading any data, so a bad
    # argument fails fast instead of raising IndexError (or silently picking
    # a model via a negative index).
    n = None
    if len(sys.argv) == 2:
        try:
            n = int(sys.argv[1]) - 1
        except ValueError:
            n = None
    if n is None:
        sys.stderr.write('Invalid arguments\n')
        sys.stderr.write('Usage python2 train_DNN.py <model_number>\n')
        sys.stderr.write('1 - CNN\n')
        sys.stderr.write('2 - LSTM\n')
        sys.exit(-1)
    if not 0 <= n < len(models):
        sys.stderr.write('Model Not Implemented yet\n')
        sys.exit(-1)
    print('model given', models[n])

    # Read data. Assignments at module level already bind the globals that
    # evaluateModel reads, so no `global` statement is needed here.
    x_train, x_test, y_train, y_test = get_data(flatten=False)
    # Fix the number of one-hot columns to the model's output width, so a split
    # that happens to lack the highest label still has the right shape.
    # Assumes get_data yields integer labels in range(len(class_labels)).
    y_train = np_utils.to_categorical(y_train, len(class_labels))
    y_test = np_utils.to_categorical(y_test, len(class_labels))

    if models[n] == 'CNN':
        # Conv2D needs a trailing channel axis: (samples, rows, cols, 1).
        in_shape = x_train[0].shape
        print(x_train.shape)
        print(in_shape)
        x_train = x_train.reshape(x_train.shape[0], in_shape[0], in_shape[1], 1)
        x_test = x_test.reshape(x_test.shape[0], in_shape[0], in_shape[1], 1)

    model = get_model(models[n], x_train[0].shape)
    best_model_path = '../models/best_model_2.h5'
    evaluateModel(model)