Example #1
0
    sess = tf.Session()
else:
    sess = tf.Session(config=tf.ConfigProto(
        intra_op_parallelism_threads=int(nthr)))
# Run TensorFlow's global-variable initializer op in the session created above.
sess.run(tf.global_variables_initializer())

########################################################################
# Load saved weights if any
#
# niter > 0 : resume from the checkpoint pair saved at that iteration
#             (network weights + optimizer state) under `wts`.
# otherwise : start from the highest-numbered model found in the sibling
#             "pretrain" directory.
if niter > 0:
    mfn = wts+"/iter_%06d.model.npz" % niter
    sfn = wts+"/iter_%06d.state.npz" % niter

    ut.mprint("Restoring model from " + mfn )
    ut.loadNet(mfn,model,sess)
    ut.mprint("Restoring state from " + sfn )
    ut.loadAdam(sfn,opt,model.weights,sess)
    ut.mprint("Done!")

else:
    # Load the last model in pretraining
    wcard = os.path.dirname(wts)+"/pretrain/iter_*.model.npz"
    # Pair every matching file with the iteration number embedded in its
    # name.  The pattern is a raw string so that `\d` reaches the regex
    # engine as written instead of being an invalid string escape.
    lst=[(l,int(re.match(r'.*/.*_(\d+)',l).group(1))) for l in glob(wcard)]
    if not lst:
        # Fail with the wildcard in the message rather than letting max()
        # complain about an empty sequence (same exception type as before).
        raise ValueError("No pre-trained model found matching " + wcard)
    # Pick the file with the largest iteration number.
    mfn = max(lst, key=lambda x: x[1])[0]
    ut.mprint("Start with pre-trained model " + mfn )
    ut.loadNet(mfn,model,sess)

#########################################################################
# Main Training loop

# Flag initialised to False before the loop.
# NOTE(review): presumably set to True elsewhere to request a clean exit
# from the training loop -- confirm against the code that follows.
stop=False
# Log the iteration number training starts from.
ut.mprint("Starting from Iteration %d" % niter)
Example #2
0
# Run TensorFlow's global-variable initializer op in the session.
sess.run(tf.global_variables_initializer())

# Load Data File Names
# Each list file holds one entry per line; only the trailing newline is
# stripped.  The files are read inside `with` blocks so their handles are
# closed as soon as the lists are built.
with open('data/train.txt') as fh:
    tlist = [f.rstrip('\n') for f in fh]
with open('data/val.txt') as fh:
    vlist = [f.rstrip('\n') for f in fh]

# Number of complete batches of size BSZ in each list (integer division
# drops any partial batch).
ESIZE=len(tlist)//BSZ
VESIZE=len(vlist)//BSZ

# Setup save/restore
# Iteration number reported by the saver; a value > 0 triggers a restore.
origiter = saver.iter
# Fixed seed, so the sequence of permutations drawn from `rs` is the same
# on every run.
rs = np.random.RandomState(0)
if origiter > 0:
    # Restore the network weights from the saver's latest checkpoint.
    ut.loadNet(saver.latest,net,sess)
    # The optimizer state is optional: it is loaded only when the file exists.
    if os.path.isfile('wts/opt.npz'):
        ut.loadAdam('wts/opt.npz',opt,net.weights,sess)
    
    # Draw ceil(origiter / ESIZE) permutations of the training list; only
    # the last one remains in `idx`.
    # NOTE(review): presumably this fast-forwards `rs` to the state it had
    # when training stopped, so the data order continues unchanged --
    # confirm against the training loop.
    for k in range( (origiter+ESIZE-1) // ESIZE):
        idx = rs.permutation(len(tlist))
    ut.mprint("Restored to iteration %d" % origiter)    

    
# Main Training Loop    
# The iteration counter starts from origiter (the saver's iteration number).
niter = origiter    
# Float accumulator, zeroed before the loop starts.
# NOTE(review): presumably accumulates training outputs between log
# points -- confirm against the loop body.
touts = 0.
while niter < MAXITER+1:
    
    if niter % VALITER == 0:
        vouts = 0.
        for j in range(VALREP):
            off = j % (len(vlist)%BSZ + 1)