/
train_model.py
71 lines (56 loc) · 2.16 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import pickle
import numpy as np
import argparse
import train_dAs as trainer
def run(file, epochs, hidden_nodes):
    """Train one model per (hidden nodes, epochs) combination and save them all.

    Loads the pickled array stored at ``file``, takes every column except the
    last as the training input, calls ``trainer.train_da`` once for each
    combination of ``hidden_nodes`` and ``epochs``, and finally hands the
    collected results to ``save_run``.

    Args:
        file: Path to a pickle file. The unpickled object must expose
            ``.shape`` and support 2-D slicing (``patients[:, :-1]``).
        epochs: Iterable of training-epoch counts; each value is passed as
            ``training_epochs``.
        hidden_nodes: Iterable of hidden-layer sizes; each value is passed as
            ``n_hidden``.

    Returns:
        None. Results are written to disk by ``save_run``.
    """
    print(file, ' ', epochs, ' ', hidden_nodes)

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point this at files produced by a trusted pipeline.
    with open(file, 'rb') as f:
        patients = pickle.load(f)
    print(patients.shape)

    # The last column is excluded from the training input. The original code
    # also cast that column to int into an unused local; it was dead code and
    # has been removed.
    X = patients[:, :-1]

    # Results are keyed '<hidden nodes>_<epochs>'; save_run splits the key on
    # '_' to rebuild the output file name.
    dA = {}
    for hn in hidden_nodes:
        for e in epochs:
            name = str(hn) + '_' + str(e)
            # 'coruption_rate' is spelled exactly as the existing call to
            # train_da spells it; renaming it here would break the call
            # unless train_dAs is changed too.
            dA[name] = trainer.train_da(X, learning_rate=0.1,
                                        coruption_rate=0.2,
                                        batch_size=10,
                                        training_epochs=e,
                                        n_hidden=hn)

    save_run(file, dA)
def save_run(file, dA):
    """Pickle every entry of ``dA`` into ``./data/<model>/trained/``.

    Each value is written to ``./data/<model>/trained/<hn>_<epoch>.p`` where
    ``<hn>`` and ``<epoch>`` come from splitting the entry's key on ``'_'``.

    Args:
        file: Path of the input file. Its third '/'-separated component is
            used as the model folder name, e.g. ``'./data/m1/patients/x.p'``
            yields ``'m1'``.
        dA: Mapping from ``'<hn>_<epoch>'`` keys to picklable objects. An
            empty mapping is a no-op.

    Raises:
        IndexError: If ``dA`` is non-empty and ``file`` has fewer than three
            '/'-separated components.
        ValueError: If a key does not contain exactly one ``'_'``.
    """
    if not dA:
        # Nothing to write; also keeps the original behaviour of not
        # inspecting `file` at all when there are no results.
        return

    # The model folder only depends on `file`, so resolve it once.
    model = file.split('/')[2]
    out_dir = './data/' + model + '/trained/'

    # Create the output folder on first use instead of failing with
    # FileNotFoundError when it has not been prepared by hand.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    for key, value in dA.items():
        hn, epoch = key.split('_')
        file_name = out_dir + hn + '_' + epoch + '.p'
        # Context manager guarantees the handle is flushed and closed; the
        # original left every file open until garbage collection.
        with open(file_name, 'wb') as f:
            pickle.dump(value, f)
if __name__ == "__main__":
    # Command-line entry point: build the input path
    # ./data/<model_name>/patients/<run_name>.p and train one model for every
    # (hidden_nodes, training_epochs) combination via run().
    parser = argparse.ArgumentParser()
    # Both identifiers are mandatory; argparse reports a usage error itself
    # instead of the script raising a bare Exception after parsing.
    parser.add_argument("--model_name", required=True,
                        help="model number (folder name)")
    parser.add_argument("--run_name", required=True,
                        help="id for run (file name)")
    # type=int converts each command-line value; the list defaults are used
    # as-is when the option is omitted, so no post-parse None check or manual
    # int() conversion is needed.
    parser.add_argument("--training_epochs", nargs='*', type=int,
                        default=[10],
                        help="list of training epochs to save")
    parser.add_argument("--hidden_nodes", nargs='*', type=int,
                        default=[2],
                        help="list of hidden nodes to save")
    args = parser.parse_args()

    # Keep the './data/<model>/...' layout: save_run() derives the model
    # folder from the third '/'-separated component of this path.
    file_name = ('./data/' + args.model_name + '/patients/' +
                 args.run_name + '.p')
    print(file_name)

    run(file=file_name,
        epochs=args.training_epochs,
        hidden_nodes=args.hidden_nodes)