train.py — training script (forked from ayanc/fdscs); executable file, 118 lines (93 loc), 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python3
#--Ayan Chakrabarti <ayan@wustl.edu>
import tensorflow as tf
import numpy as np
import utils as ut
import model
import data
import dops
import os.path
import sys
# Hyper-parameters
BSZ = 4          # batch size
LR = 1e-4        # initial learning rate (may be lowered below when resuming)
WD = 1e-5        # weight-decay coefficient
MAXITER = 500e3  # iterations in the first training stage
SAVEITER = 1e4   # checkpoint every this many iterations
DISPITER = 10    # print training metrics every this many iterations
VALITER = 1000   # run validation every this many iterations
VALREP = 2       # validation passes (with shifted offsets) per evaluation

saver = ut.ckpter('wts/model*.npz')

# Staged schedule: if a restored checkpoint has already finished a stage,
# extend MAXITER to the next stage and drop the learning rate. The second
# stage's test is against the (possibly updated) MAXITER, so both stages
# can trigger in sequence on a sufficiently advanced checkpoint.
for stage_end, stage_lr in ((550e3, 1e-5), (600e3, 1e-6)):
    if saver.iter >= MAXITER:
        MAXITER = stage_end
        LR = stage_lr
#### Build Graph
# Build phase2
# Dataset pipeline for batches of size BSZ; exposes image/cost-volume/disparity
# tensors used below (see data.py for exact shapes — not visible from here).
d = data.dataset(BSZ)
net = model.Net()
# Forward pass. NOTE(review): d.cv presumably a pre-computed cost volume and
# d.lrl the left/right input pair — confirm against data.dataset.
output = net.predict(d.limgs, d.cv, d.lrl)
# tloss drives the optimizer; loss/l1/pc/pc3 are reporting metrics computed
# against ground-truth disparity d.disp under validity mask d.mask.
tloss, loss, l1, pc, pc3 = dops.metrics(output,d.disp,d.mask)
vals = [loss,pc,l1,pc3]                    # metrics fetched each step
tnms = ['loss.t','pc.t','L1.t','pc3.t']    # training-metric display names
vnms = ['loss.v','pc.v','L1.v','pc3.v']    # validation-metric display names
opt = tf.train.AdamOptimizer(LR)
# Minimize training loss plus weight decay, over network weights only.
tstep = opt.minimize(tloss+WD*net.wd,var_list=list(net.weights.values()))
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=4))
sess.run(tf.global_variables_initializer())
# Load data file names (one relative path per line). Use `with` so the file
# handles are closed promptly instead of leaking until GC (the original
# open(...).readlines() never closed them).
with open('data/train.txt') as f:
    tlist = [line.rstrip('\n') for line in f]
with open('data/val.txt') as f:
    vlist = [line.rstrip('\n') for line in f]
ESIZE = len(tlist)//BSZ    # training batches per epoch
VESIZE = len(vlist)//BSZ   # validation batches per pass

# Setup save/restore
origiter = saver.iter
rs = np.random.RandomState(0)
if origiter > 0:
    ut.loadNet(saver.latest, net, sess)
    if os.path.isfile('wts/opt.npz'):
        ut.loadAdam('wts/opt.npz', opt, net.weights, sess)
    # Replay one permutation per completed/started epoch so the seeded RNG
    # reaches the same state it had when training stopped: the resumed epoch
    # then sees the same shuffle order it would have without the restart.
    for k in range((origiter+ESIZE-1)//ESIZE):
        idx = rs.permutation(len(tlist))
    ut.mprint("Restored to iteration %d" % origiter)
# Main Training Loop
niter = origiter   # global iteration counter (resumes from checkpoint)
touts = 0.         # running sum of training metrics since last display
while niter < MAXITER+1:
    # Periodic validation: VALREP passes over the val set, each pass shifted
    # by a different offset so the leftover (len(vlist) % BSZ) samples are
    # eventually covered. Metrics are averaged over all passes/batches.
    if niter % VALITER == 0:
        vouts = 0.
        for j in range(VALREP):
            off = j % (len(vlist)%BSZ + 1)
            for b in range(VESIZE):
                blst = vlist[(b*BSZ+off):((b+1)*BSZ+off)]
                outs = sess.run(vals,feed_dict=d.fdict(blst))
                vouts = vouts + np.float32(outs)
        vouts = vouts / np.float32(VESIZE*VALREP)
        ut.vprint(niter,vnms,vouts)
        # Stop only right after a final validation at exactly MAXITER.
        if niter == MAXITER:
            break
    # Reshuffle the training list at each epoch boundary.
    if niter % ESIZE == 0:
        idx = rs.permutation(len(tlist))
    # Assemble this iteration's batch from the shuffled index.
    blst = [tlist[idx[(niter%ESIZE)*BSZ+b]] for b in range(BSZ)]
    # One optimizer step; also fetch metrics for the running display average.
    outs,_ = sess.run([vals,tstep],feed_dict=d.fdict(blst))
    niter = niter+1
    touts = touts+np.float32(outs)
    # Periodic checkpoint of network weights (old checkpoints pruned).
    if niter % SAVEITER == 0:
        ut.saveNet('wts/model_%d.npz'%niter,net,sess)
        saver.clean(every=SAVEITER,last=1)
        ut.mprint('Saved Model')
    # Periodic display of metrics averaged over the last DISPITER steps.
    if niter % DISPITER == 0:
        touts = touts/np.float32(DISPITER)
        ut.vprint(niter,['lr']+tnms,[LR]+list(touts))
        touts = 0.
    # Cooperative shutdown flag (presumably set by a signal handler in
    # utils — confirm); allows a clean save below on interrupt.
    if ut.stop:
        break
# Save final state if we progressed beyond the last checkpoint / restart point.
if niter > saver.iter:
    ut.saveNet('wts/model_%d.npz'%niter,net,sess)
    saver.clean(every=SAVEITER,last=1)
    ut.mprint('Saved Model')
if niter > origiter:
    ut.saveAdam('wts/opt.npz',opt,net.weights,sess)
    ut.mprint("Saved Optimizer.")