/
train.py
112 lines (91 loc) · 4.55 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/python3
import tensorflow as tf
import numpy as np
import pandas as pd
import time, os, sys
import argparse
# User-defined
from network import Network
from diagnostics import Diagnostics
from data import Data
from model import Model
from config import config_train, directories
# Only surface TensorFlow log messages at ERROR level or above.
tf.logging.set_verbosity(tf.logging.ERROR)
def train(config, architecture, args):
    """Build the model graph and run the training loop, saving checkpoints.

    At the start of every epoch the training iterator is re-initialised and
    diagnostics are run; the optimizer is then stepped until the training
    iterator raises ``tf.errors.OutOfRangeError``. A final checkpoint is
    written once all epochs have completed.

    Args:
        config: Training configuration object. ``n_classes`` (only when the
            dataset is ``cifar100``) and ``L`` are overwritten in place from
            ``args``; ``num_epochs`` is read.
        architecture: Human-readable summary string, printed once at start.
        args: Parsed command-line namespace. Reads ``dataset``,
            ``langevin_iterations``, ``name``, ``optimizer``,
            ``restore_last`` and ``restore_path``.

    Raises:
        SystemExit: On ``KeyboardInterrupt`` during a training step, after a
            ``*_last.ckpt`` checkpoint has been saved.
    """
    print('Architecture: {}'.format(architecture))
    start_time = time.time()
    v_acc_best = 0.
    # May be None when the checkpoint directory holds no valid checkpoint.
    ckpt = tf.train.get_checkpoint_state(directories.checkpoints)

    if args.dataset == 'cifar100':
        config.n_classes = 100

    config.L = args.langevin_iterations

    # Build graph
    cnn = Model(config, directories, name=args.name, optimizer=args.optimizer)
    saver = tf.train.Saver()

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        train_handle = sess.run(cnn.train_iterator.string_handle())
        test_handle = sess.run(cnn.test_iterator.string_handle())

        # Guard against ckpt being None so that --restore_last on a fresh
        # checkpoint directory falls through instead of raising
        # AttributeError.
        if args.restore_last and ckpt is not None and ckpt.model_checkpoint_path:
            # Continue training saved model
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('{} restored.'.format(ckpt.model_checkpoint_path))
        elif args.restore_path:
            new_saver = tf.train.import_meta_graph('{}.meta'.format(args.restore_path))
            new_saver.restore(sess, args.restore_path)
            print('{} restored.'.format(args.restore_path))

        sess.run(cnn.test_iterator.initializer)

        # Keeps the final save well-defined even if num_epochs is 0.
        epoch = 0
        for epoch in range(config.num_epochs):
            sess.run(cnn.train_iterator.initializer)

            # Run diagnostics (uses the config passed in, not the module
            # global, so the function honours its own argument).
            v_acc_best = Diagnostics.run_diagnostics(cnn, config, directories, sess, saver, train_handle,
                                                     test_handle, start_time, v_acc_best, epoch, args.name)

            while True:
                try:
                    # Run SGLD iterations
                    if args.optimizer == 'entropy-sgd':
                        for _ in range(config.L):
                            sess.run([cnn.sgld_op], feed_dict={cnn.training_phase: True, cnn.handle: train_handle})

                    # Update weights
                    sess.run([cnn.train_op, cnn.update_accuracy], feed_dict={cnn.training_phase: True,
                                                                             cnn.handle: train_handle})

                except tf.errors.OutOfRangeError:
                    # Training iterator exhausted: this epoch is done.
                    print('End of epoch!')
                    break

                except KeyboardInterrupt:
                    # Persist progress before exiting on Ctrl-C.
                    save_path = saver.save(sess, os.path.join(directories.checkpoints,
                                                              'cnn_{}_last.ckpt'.format(args.name)), global_step=epoch)
                    print('Interrupted, model saved to: ', save_path)
                    sys.exit()

        save_path = saver.save(sess, os.path.join(directories.checkpoints,
                                                  'cnn_{}_end.ckpt'.format(args.name)),
                               global_step=epoch)

    print("Training Complete. Model saved to file: {} Time elapsed: {:.3f} s".format(save_path, time.time()-start_time))
def main(**kwargs):
    """Parse command-line options, set up the dataset and launch training.

    ``kwargs`` is accepted for interface compatibility and is not used.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-rl", "--restore_last", help="restore last saved model", action="store_true")
    parser.add_argument("-r", "--restore_path", help="path to model to be restored", type=str)
    parser.add_argument("-opt", "--optimizer", default="entropy-sgd", help="Selected optimizer", type=str,
                        choices=['entropy-sgd', 'adam', 'momentum', 'sgd'])
    parser.add_argument("-n", "--name", default="entropy-sgd", help="Checkpoint/Tensorboard label")
    parser.add_argument("-d", "--dataset", default="cifar10", help="Dataset to train on (cifar10 || cifar100)",
                        type=str, choices=['cifar10', 'cifar100'])
    parser.add_argument("-L", "--langevin_iterations", default=20, help="Number of Langevin iterations in inner loop.",
                        type=int)
    args = parser.parse_args()
    config = config_train

    # Report the SGLD iteration count requested on the command line:
    # train() overwrites config.L with args.langevin_iterations, so reading
    # config.L at this point would print a stale value.
    architecture = 'Layers: {} | Conv dropout: {} | Base LR: {} | SGLD Iterations {} | Epochs: {} | Optimizer: {}'.format(
        config.n_layers,
        config.conv_keep_prob,
        config.learning_rate,
        args.langevin_iterations,
        config.num_epochs,
        args.optimizer
    )

    Diagnostics.setup_dataset(args.dataset)

    # Launch training
    train(config, architecture, args)
# Run main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()