Example #1
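# Leave-one-subject-out experiment on the MSRA15 hand-pose dataset: one subject
# (args.test_sub) is held out for testing, the remaining eight are used for training.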
def main(_):
    start = time.perf_counter()  # time.clock() was removed in Python 3.8

    if args.dataset == 'MSRA':
        di = MSRA15Importer(args.data_root, cacheDir='./cache/MSRA/', refineNet=None)
        Seq_all = []
        MID = args.test_sub
        for seq in range(9):     
            if seq == MID:
                Seq_train_ = di.loadSequence('P{}'.format(seq), rng=rng, shuffle=False, docom=False, cube=(175, 175, 175))
            else:
                Seq_train_ = di.loadSequence('P{}'.format(seq), rng=rng, shuffle=False, docom=False, cube=None)
            Seq_all.append(Seq_train_)

        Seq_test_raw = Seq_all.pop(MID)
        Seq_test = Seq_test_raw.data
        Seq_train = [seq_data for seq_ in Seq_all for seq_data in seq_.data]

        train_num = len(Seq_train)
        print('loading done: %d train samples' % train_num)
        imgs = np.asarray([d.dpt.copy() for d in Seq_train], 'float32')
        gt3Dcrops = np.asarray([d.gt3Dcrop for d in Seq_train], dtype='float32')
        M = np.asarray([d.T for d in Seq_train], dtype='float32')
        com2D = np.asarray([d.com2D for d in Seq_train], 'float32')
        cube = np.asarray([d.cube for d in Seq_train], 'float32')
        # uv_crop = np.asarray([d.gtcrop for d in Seq_train], dtype='float32')[:, :, 0:-1]
        del Seq_train

        train_stream = MultiDataStream([imgs, gt3Dcrops, M, com2D, cube])
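        # MultiDataStream (above) is assumed to bundle the aligned arrays into one iterable stream of samples.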
    else:
        raise ValueError('unknown dataset %s' % args.dataset)
    
    test_num = len(Seq_test)
    print('loading done: %d test samples' % test_num)
    test_gt3Dcrops = np.asarray([d.gt3Dcrop for d in Seq_test], dtype='float32')
    test_M = np.asarray([d.T for d in Seq_test], dtype='float32')
    # test_com2D = np.asarray([d.com2D for d in Seq_test], 'float32')  
    # test_uv_crop = np.asarray([d.gtcrop for d in Seq_test], dtype='float32')[:, :, 0:-1]
    test_uv = np.asarray([d.gtorig for d in Seq_test], 'float32')[:, :, 0:-1]
    test_com3D = np.asarray([d.com3D for d in Seq_test], 'float32') 
    test_cube = np.asarray([d.cube for d in Seq_test], 'float32')
    test_imgs = np.asarray([d.dpt.copy() for d in Seq_test], 'float32')
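    # Normalize each depth crop around the hand's 3D center; norm_dm is assumed to map depth values into [-1, 1].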
    test_data = np.empty_like(test_imgs)
    for it in range(test_num):
        test_data[it] = norm_dm(test_imgs[it], test_com3D[it], test_cube[it])
    del Seq_test
    test_stream = MultiDataStream([test_data, test_gt3Dcrops, test_M, test_com3D, test_uv, test_cube])
    # Samples beyond the last full batch are kept aside so they are not silently dropped.
    clip_index = (test_num // args.batch_size) * args.batch_size  # np.int was removed in NumPy 1.24
    extra_test_data = [test_data[clip_index:], test_gt3Dcrops[clip_index:], test_M[clip_index:],
                       test_com3D[clip_index:], test_uv[clip_index:], test_cube[clip_index:]]

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
    tf.set_random_seed(1)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        model = Model(sess, args)
        if args.phase == 'train':
            model.train(args, train_stream, test_stream)
        else:
            # Pass the leftover samples computed above; the original passed None, leaving them unused.
            model.test(args, test_stream, extra_test_data=extra_test_data)
        end = time.perf_counter()
        print('running time: %f s' % (end - start))
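
A minimal sketch of the norm_dm helper assumed above (hypothetical; the real implementation lives in the repository). It clips a depth crop to the cube around the hand center and scales it into [-1, 1]:

def norm_dm(dpt, com3D, cube):
    dpt = dpt.copy()
    dpt[dpt == 0] = com3D[2] + cube[2] / 2.  # treat missing depth as far background
    dpt = np.clip(dpt, com3D[2] - cube[2] / 2., com3D[2] + cube[2] / 2.)
    return (dpt - com3D[2]) / (cube[2] / 2.)  # center on the hand, scale by half the cube depth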
Example #2
    def getPCL(self, dpt, T):
        """
        Get pointcloud from frame
        :param dpt: depth image
        :param T: 2D transformation of crop
        """

        return MSRA15Importer.depthToPCL(dpt, T)
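
    # Hypothetical usage, given a frame `f` from a loaded sequence:
    #   pcl = dataset.getPCL(f.dpt, f.T)  # -> (N, 3) array of 3D points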
    def __init__(self, imgSeqs=None, basepath=None, localCache=True):
        """
        constructor
        """
        super(MSRA15Dataset, self).__init__(imgSeqs, localCache)
        if basepath is None:
            basepath = '../../data/MSRA15/'

        self.lmi = MSRA15Importer(basepath)
from util.handdetector import HandDetector
import numpy as np
from netlib.basemodel import basenet2
from data.transformations import transformPoints2D
from data.importers import MSRA15Importer
import argparse

parser = argparse.ArgumentParser(description='set test subject')
parser.add_argument('--test-sub', type=int, default=None)
args = parser.parse_args()

rng = np.random.RandomState(23455)
import tensorflow as tf

train_root = '/content/cvpr15_MSRAHandGestureDB'
shuffle = False
di = MSRA15Importer(train_root, cacheDir='../../cache/MSRA', refineNet=None)

Seq_all = []
MID = args.test_sub
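# Subject P<MID> is held out as the test set; the remaining subjects are shuffled for training.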
for seq in range(9):
    shuffle = True
    if seq == MID:
        shuffle = False
        Seq_train_ = di.loadSequence('P{}'.format(seq),
                                     rng=rng,
                                     shuffle=shuffle,
                                     docom=True,
                                     cube=(175, 175, 175))
    else:
        Seq_train_ = di.loadSequence('P{}'.format(seq),
                                     rng=rng,
                                     shuffle=shuffle,
                                     docom=True)  # snippet truncated in the source; completed to mirror the branch above
    Seq_all.append(Seq_train_)
Example #5
import os
import numpy
from data.importers import MSRA15Importer
from data.dataset import MSRA15Dataset
from util.handpose_evaluation import MSRAHandposeEvaluation
from util.helpers import shuffle_many_inplace

if __name__ == '__main__':

    eval_prefix = 'MSRA15_COM_AUGMENT'
    if not os.path.exists('./eval/'+eval_prefix+'/'):
        os.makedirs('./eval/'+eval_prefix+'/')

    rng = numpy.random.RandomState(23455)

    print("create data")
    aug_modes = ['com', 'rot', 'none']  # 'sc',

    di = MSRA15Importer('../data/MSRA15/')
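    # Each subject is loaded twice: with ground-truth crops (docom=False, '_gt')
    # and with center-of-mass crops (docom=True, '_com').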
    Seq0_1 = di.loadSequence('P0', shuffle=True, rng=rng, docom=False)
    Seq0_1 = Seq0_1._replace(name='P0_gt')
    Seq0_2 = di.loadSequence('P0', shuffle=True, rng=rng, docom=True)
    Seq0_2 = Seq0_2._replace(name='P0_com')
    Seq1_1 = di.loadSequence('P1', shuffle=True, rng=rng, docom=False)
    Seq1_1 = Seq1_1._replace(name='P1_gt')
    Seq1_2 = di.loadSequence('P1', shuffle=True, rng=rng, docom=True)
    Seq1_2 = Seq1_2._replace(name='P1_com')
    Seq2_1 = di.loadSequence('P2', shuffle=True, rng=rng, docom=False)
    Seq2_1 = Seq2_1._replace(name='P2_gt')
    Seq2_2 = di.loadSequence('P2', shuffle=True, rng=rng, docom=True)
    Seq2_2 = Seq2_2._replace(name='P2_com')
    Seq3_1 = di.loadSequence('P3', shuffle=True, rng=rng, docom=False)
    Seq3_1 = Seq3_1._replace(name='P3_gt')
    Seq3_2 = di.loadSequence('P3', shuffle=True, rng=rng, docom=True)
    Seq3_2 = Seq3_2._replace(name='P3_com')  # snippet truncated in the source; completed to match the pattern above
Example #6
        print('Start frame index can be specified: -s <start_idx> or enter now:')
        start_idx = input().lower()
        if len(start_idx.strip()) == 0:
            start_idx = 0
        else:
            start_idx = int(start_idx)
    else:
        print('Start frame index is {}'.format(start_idx))

    rng = numpy.random.RandomState(23455)

    # subset to label
    subset_idxs = []

    if person == 'P0':
        di = MSRA15Importer('/home/boonyew/Documents/semi-auto-anno-master/semi-auto-anno/data/msra/', useCache=False)
        Seq2 = di.loadSequence(person, shuffle=False)
        hc = MSRAHandConstraints([Seq2.name])
        hpe = MSRAHandposeEvaluation([j.gt3Dorig for j in Seq2.data], [j.gt3Dorig for j in Seq2.data])
        for idx, seq in enumerate(Seq2.data):
            ed = {'vis': [], 'pb': {'pb': [], 'pbp': []}}
            Seq2.data[idx] = seq._replace(gtorig=numpy.zeros_like(seq.gtorig), extraData=ed)

        # common subset for all
        subset_idxs = [16, 21, 26, 29, 45, 49, 52, 54, 58, 104, 108, 114, 138, 144, 148, 170, 175, 178, 210, 214, 217, 231, 237, 249, 252, 259, 264, 283, 287, 296, 307, 345, 370, 381, 384, 386, 405, 412, 423, 429, 436, 458, 465, 469, 490, 498, 505, 526, 530, 533, 537, 546, 553, 576, 607, 612, 624, 631, 657, 667, 669, 673, 685, 697, 704, 735, 742, 751, 765, 781, 784, 789, 793, 801, 805, 816, 820, 827, 830, 874, 886, 888, 893, 896, 899, 911, 923, 934, 962, 969, 983, 1023, 1027, 1029, 1034, 1046, 1054, 1057, 1070, 1075, 1085, 1093, 1098, 1110, 1114, 1134, 1138, 1146, 1173, 1181, 1184, 1188, 1191, 1194, 1208, 1213, 1221, 1224, 1228, 1241, 1248, 1251, 1255, 1262, 1267, 1274, 1286, 1295, 1308, 1312, 1335, 1341, 1349, 1353, 1383, 1386, 1389, 1410, 1414, 1422, 1432, 1449, 1452, 1455, 1465, 1473, 1477, 1489, 1504, 1523, 1532, 1542, 1550, 1552, 1571, 1580, 1586, 1591, 1609, 1613, 1617, 1628, 1632, 1644, 1653, 1656, 1688, 1694, 1695, 1698, 1709, 1713, 1725, 1745, 1752, 1756, 1762, 1772, 1778, 1795, 1812, 1814, 1817, 1830, 1833, 1848, 1853, 1858, 1864, 1869, 1873, 1887, 1892, 1897, 1904, 1927, 1930, 1934, 1937, 1965, 1973, 1978, 1991, 2017, 2028, 2033, 2048, 2055, 2058, 2067, 2074, 2094, 2131, 2137, 2146, 2150, 2166, 2170, 2177, 2185, 2191, 2196, 2203, 2208, 2213, 2222, 2255, 2269, 2273, 2288, 2291, 2298, 2305, 2325, 2331, 2334, 2339, 2343, 2347, 2351, 2372, 2380, 2390, 2394, 2416, 2428, 2434, 2462, 2468, 2484, 2497, 2504, 2509, 2511, 2515, 2529, 2543, 2566, 2572, 2584, 2590, 2609, 2617, 2627, 2631, 2644, 2651, 2654, 2661, 2685, 2687, 2693, 2702, 2737, 2749, 2754, 2763, 2775, 2778, 2790, 2792, 2808, 2813, 2816, 2820, 2829, 2835, 2852, 2856, 2872, 2891, 2898, 2905, 2911, 2942, 2945, 2949, 2952, 2989, 3011, 3015, 3031, 3034, 3037]
    else:
        raise NotImplementedError("only subject P0 is handled in this snippet")

    replace_off = 0
    replace_file = None  # './params_tracking.npz'
#from util.cameradevice import CreativeCameraDevice, FileDevice


__author__ = "Markus Oberweger <*****@*****.**>"
__copyright__ = "Copyright 2015, ICG, Graz University of Technology, Austria"
__credits__ = ["Markus Oberweger"]
__license__ = "GPL"
__version__ = "1.0"
__maintainer__ = "Markus Oberweger"
__email__ = "*****@*****.**"
__status__ = "Development"

if __name__ == '__main__':
    rng = numpy.random.RandomState(23455)

    di = MSRA15Importer('/content/drive/My Drive/KNOWLEDGE ENGINEERING/KE Semester 4/Core Course/CA2 (Matthew)/cvpr15_MSRAHandGestureDB/')
    Seq2 = di.loadSequence('P0')
    testSeqs = [Seq2]

    # di = ICVLImporter('../data/ICVL/')
    # Seq2 = di.loadSequence('test_seq_1')
    # testSeqs = [Seq2]

    #di = NYUImporter('../data/NYU/')
    #Seq2 = di.loadSequence('test_1')
    #testSeqs = [Seq2]

    # load trained network
    poseNetParams = ResNetParams(type=1, nChan=1, wIn=128, hIn=128, batchSize=1, numJoints=14, nDims=3)
    poseNetParams.loadFile = "/content/deep-prior-pp/src/eval/MSRA_network_prior_0.pkl"
    comrefNetParams = ScaleNetParams(type=1, nChan=1, wIn=128, hIn=128, batchSize=1, resizeFactor=2, numJoints=1, nDims=3)
Example #8
import os
import pickle
import sys
import numpy as np
import tensorflow as tf
from data.importers import MSRA15Importer
from data.dataset import MSRA15Dataset
# from util.handpose_evaluation import MSRAHandposeEvaluation
from sklearn.decomposition import PCA

if __name__ == '__main__':

    rng = np.random.RandomState(23455)

    print("create data")
    aug_modes = ['com', 'rot', 'none']  # 'sc',

    comref = None  # "./eval/MSRA15_COM_AUGMENT/net_MSRA15_COM_AUGMENT.pkl"
    docom = False
    di = MSRA15Importer('../data/MSRA15/', refineNet=comref)
    seqs = []
    # seqs.append(di.loadSequence('P0', shuffle=True, rng=rng, docom=docom))
    # seqs.append(di.loadSequence('P1', shuffle=True, rng=rng, docom=docom))
    # seqs.append(di.loadSequence('P2', shuffle=True, rng=rng, docom=docom))
    # seqs.append(di.loadSequence('P3', shuffle=True, rng=rng, docom=docom))
    # seqs.append(di.loadSequence('P4', shuffle=True, rng=rng, docom=docom))
    # seqs.append(di.loadSequence('P5', shuffle=True, rng=rng, docom=docom))
    # seqs.append(di.loadSequence('P6', shuffle=True, rng=rng, docom=docom))
    # seqs.append(di.loadSequence('P7', shuffle=True, rng=rng, docom=docom))
    # seqs.append(di.loadSequence('P8', shuffle=True, rng=rng, docom=docom))
    seqs.append(
        di.loadSequence('Retrain',
                        shuffle=True,
                        rng=rng,
                        docom=docom,
                        cube=(200, 200, 200)))  # snippet truncated in the source; cube taken from the commented-out call below


def main(argv):
    rng = np.random.RandomState(23455)

    print("create data")
    # aug_modes = ['com', 'rot', 'none']  # 'sc',

    comref = None  # "./eval/MSRA15_COM_AUGMENT/net_MSRA15_COM_AUGMENT.pkl"
    docom = False
    di = MSRA15Importer('../data/MSRA15/', refineNet=comref)

    # testSeqs = [di.loadSequence('Retrain', shuffle=True, rng=rng, docom=docom, cube=(200,200,200))]
    testSeqs = [di.loadSequence('P0', shuffle=True, rng=rng, docom=docom)]

    testDataSet = MSRA15Dataset(testSeqs)
    test_data, test_gt3D = testDataSet.imgStackDepthOnly(testSeqs[0].name)
    test_data_cube = np.asarray([testSeqs[0].config['cube']] * test_data.shape[0])
    test_data = np.transpose(test_data, axes=[0, 2, 3, 1])

    test_label = np.reshape(test_gt3D, [-1, test_gt3D.shape[1] * 3])
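    # test_label holds (N, 3*J) vectors flattened from the (N, J, 3) normalized joint coordinates.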

    model = get_model()
    model.build_loss(weight_decay=FLAGS.weight_decay,
                     lr=FLAGS.init_lr,
                     lr_decay_rate=FLAGS.lr_decay_rate,
                     lr_decay_step=FLAGS.lr_decay_step)

    epoch_batch_num = test_data.shape[0] // FLAGS.batch_size

    model_path_suffix = '{}_{}_stage{}'.format(FLAGS.model_name, FLAGS.data_set, FLAGS.stages)
    # path from which to load the trained model weights for testing
    model_weights_path = os.path.join(FLAGS.cacheDir,
                                      FLAGS.weightDir,
                                      model_path_suffix,
                                      '{}-{}'.format(FLAGS.model_name, FLAGS.test_iters))

    eval_result_dir = os.path.join(FLAGS.evalDir,
                                   model_path_suffix,
                                   'test_iter{}'.format(FLAGS.test_iters))
    os.makedirs(eval_result_dir, exist_ok=True)

    joint_errors = []
    total_errors = []
    losses = []


    with tf.Session() as sess:
        saver = tf.train.Saver()
        # print(model_weights_path)
        saver.restore(sess, model_weights_path)
        print('load model from: {}'.format(model_weights_path))

        joint_infer = tf.reshape(model.model_output, shape=[-1, model.model_output.shape[1].value//3, 3])* (
                                tf.reshape(model.cube_holder, [-1, 1, 3]) / 2.)
        joint_gt = tf.reshape(model.label_holder, [-1, model.label_holder.shape[1].value // 3, 3]) * (
                                tf.reshape(model.cube_holder, [-1, 1, 3]) / 2.)
        # per-joint Euclidean error (mm) for every sample in the batch
        # joint_error_op = tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(joint_infer - joint_gt), axis=-1)), axis=0)
        joint_error_op = tf.sqrt(tf.reduce_sum(tf.square(joint_infer - joint_gt), axis=-1))

        for i in range(epoch_batch_num):
            batch_x, batch_y, batch_cube = test_data[i*FLAGS.batch_size:(i+1)*FLAGS.batch_size],\
                                           test_label[i*FLAGS.batch_size:(i+1)*FLAGS.batch_size], \
                                           test_data_cube[i*FLAGS.batch_size:(i+1)*FLAGS.batch_size],

            ops = get_test_ops(model)
            ops.append(joint_error_op)
            loss, stage_error, joint_error = sess.run(ops,
                                                      feed_dict={model.input_images: batch_x,
                                                                 model.label_holder: batch_y,
                                                                 model.cube_holder: batch_cube})

            print('step {}, total loss {}, final stage error {} mm\n'.format(i, loss, stage_error))

            losses.append(loss)
            total_errors.append(stage_error)
            joint_errors.append(joint_error)

        print('average loss {}, average final stage error {} mm \n'.format(np.mean(losses), np.mean(total_errors)))

        # save the per-joint evaluation errors
        pickle.dump(joint_errors, open('{}/joint_errors.pkl'.format(eval_result_dir), 'wb'))

        import matplotlib
        matplotlib.use('Agg')  # render figures to files instead of a display
        draw_PCK(joint_errors, eval_result_dir)
        draw_joint_error(joint_errors, eval_result_dir)
Example #10
def main(argv):

    rng = np.random.RandomState(23455)

    print("create data")
    aug_modes = ['com', 'rot', 'none']  # 'sc',

    comref = None  # "./eval/MSRA15_COM_AUGMENT/net_MSRA15_COM_AUGMENT.pkl"
    docom = False
    di = MSRA15Importer('../data/MSRA15/', refineNet=comref)
    seqs = []
    seqs.append(di.loadSequence('P0', shuffle=True, rng=rng, docom=docom))
    seqs.append(di.loadSequence('P1', shuffle=True, rng=rng, docom=docom))
    seqs.append(di.loadSequence('P2', shuffle=True, rng=rng, docom=docom))
    seqs.append(di.loadSequence('P3', shuffle=True, rng=rng, docom=docom))
    seqs.append(di.loadSequence('P4', shuffle=True, rng=rng, docom=docom))
    seqs.append(di.loadSequence('P5', shuffle=True, rng=rng, docom=docom))
    seqs.append(di.loadSequence('P6', shuffle=True, rng=rng, docom=docom))
    seqs.append(di.loadSequence('P7', shuffle=True, rng=rng, docom=docom))
    seqs.append(di.loadSequence('P8', shuffle=True, rng=rng, docom=docom))

    testSeqs = [seqs[0]]
    trainSeqs = [seq for seq in seqs if seq not in testSeqs]
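    # Leave-one-subject-out: P0 (seqs[0]) is tested, the remaining subjects are trained on.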

    print("training: {}".format(' '.join([s.name for s in trainSeqs])))
    print("testing: {}".format(' '.join([s.name for s in testSeqs])))

    # create training data
    trainDataSet = MSRA15Dataset(trainSeqs, localCache=False)
    nSamp = np.sum([len(s.data) for s in trainSeqs])
    d1, g1 = trainDataSet.imgStackDepthOnly(trainSeqs[0].name)  # normalization happens here
    # all depth images, normalized to [-1, 1]
    train_data = np.ones((nSamp, d1.shape[1], d1.shape[2], d1.shape[3]), dtype='float32')
    # all labels, normalized to [-1, 1]
    train_gt3D = np.ones((nSamp, g1.shape[1], g1.shape[2]), dtype='float32')
    # all 3D bounding cubes
    train_data_cube = np.ones((nSamp, 3), dtype='float32')
    # all hand center points (centers of mass)
    train_data_com = np.ones((nSamp, 3), dtype='float32')
    # all cropped 3D joint positions
    train_gt3Dcrop = np.ones_like(train_gt3D)
    del d1, g1
    gc.collect()
    gc.collect()
    gc.collect()
    oldIdx = 0
    for seq in trainSeqs:
        d, g = trainDataSet.imgStackDepthOnly(seq.name)
        train_data[oldIdx:oldIdx + d.shape[0]] = d
        train_gt3D[oldIdx:oldIdx + d.shape[0]] = g
        train_data_cube[oldIdx:oldIdx + d.shape[0]] = np.asarray([seq.config['cube']] * d.shape[0])
        train_data_com[oldIdx:oldIdx + d.shape[0]] = np.asarray([da.com for da in seq.data])
        train_gt3Dcrop[oldIdx:oldIdx + d.shape[0]] = np.asarray([da.gt3Dcrop for da in seq.data])
        oldIdx += d.shape[0]
        del d, g
        gc.collect()
        gc.collect()
        gc.collect()

    mb = train_data.nbytes / (1024 * 1024)
    print("data size: {} MB".format(mb))

    testDataSet = MSRA15Dataset(testSeqs)
    test_data, test_gt3D = testDataSet.imgStackDepthOnly(testSeqs[0].name)
    test_data_cube = np.asarray([testSeqs[0].config['cube']] * test_data.shape[0])

    # move channels to the last axis (NCHW -> NHWC)
    train_data = np.transpose(train_data, axes=[0, 2, 3, 1])
    test_data = np.transpose(test_data, axes=[0, 2, 3, 1])

    train_label = np.reshape(train_gt3D, [-1, train_gt3D.shape[1]*3])
    test_label = np.reshape(test_gt3D, [-1, test_gt3D.shape[1]*3])
    print(train_data.shape)

    # suffix identifying this model configuration
    model_path_suffix = '{}_{}_stage{}'.format(FLAGS.model_name, FLAGS.data_set, FLAGS.stages)
    # directory for saved model weights
    model_save_dir = os.path.join(FLAGS.cacheDir,
                                  FLAGS.weightDir,
                                  model_path_suffix)
    # directories for training and test logs
    train_log_save_dir = os.path.join(FLAGS.cacheDir,
                                      FLAGS.logDir,
                                      model_path_suffix,
                                      'train')
    test_log_save_dir = os.path.join(FLAGS.cacheDir,
                                     FLAGS.logDir,
                                     model_path_suffix,
                                     'test')

    os.makedirs(model_save_dir, exist_ok=True)
    os.makedirs(train_log_save_dir, exist_ok=True)
    os.makedirs(test_log_save_dir, exist_ok=True)

    # build the model
    model = getModel()
    model.build_loss(weight_decay=FLAGS.weight_decay,
                     lr=FLAGS.init_lr,
                     lr_decay_rate=FLAGS.lr_decay_rate,
                     lr_decay_step=FLAGS.lr_decay_step)

    train_data_generator = Data_Generator(FLAGS.batch_size, images=train_data,
                                          labels=train_label, cubes=train_data_cube)
    test_data_generator = Data_Generator(FLAGS.batch_size, images=test_data,
                                         labels=test_label, cubes=test_data_cube)
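    # Data_Generator is assumed to cycle through the arrays, yielding aligned
    # (images, labels, cubes) batches; a hedged sketch is given at the end of this example.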

    print("=====Model Build=====")

    merged_summary = tf.summary.merge_all()
    t1 = time.time()

    with tf.Session() as sess:
        # create TensorBoard writers
        train_writer = tf.summary.FileWriter(train_log_save_dir, sess.graph)
        test_writer = tf.summary.FileWriter(test_log_save_dir, sess.graph)

        # create the model saver
        saver = tf.train.Saver(max_to_keep=None)

        # initialize all variables
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # load a pretrained model, if specified
        if FLAGS.pretrained_model != '':
            saver.restore(sess, os.path.join(model_save_dir, FLAGS.pretrained_model))
            print('load model from {}'.format(os.path.join(model_save_dir, FLAGS.pretrained_model)))
            # sanity-check the restored weights
            for variable in tf.trainable_variables():
                with tf.variable_scope('', reuse=True):
                    var = tf.get_variable(variable.name.split(':0')[0])
                    print(variable.name, np.mean(sess.run(var)))

        global_step = 0
        while True:
            # fetch one batch of training data
            batch_x, batch_y, batch_cube = train_data_generator.next()

            train_ops = get_train_ops(model)
            # Forward and update weights
            # stage_losses_np, total_loss_np, _, current_lr, \
            # global_step, stage_errores_np, summaries = sess.run([model.stage_loss,
            #                                                      model.total_loss,
            #                                                      model.train_op,
            #                                                      model.cur_lr,
            #                                                      model.global_step,
            #                                                      model.stage_error,
            #                                                      merged_summary],
            #                                                     feed_dict={model.input_images: batch_x,
            #                                                                model.label_holder: batch_y,
            #                                                                model.cube_holder: batch_cube})
            train_ops.append(merged_summary)
            train_step_results = sess.run(train_ops,
                                          feed_dict={model.input_images: batch_x,
                                                     model.label_holder: batch_y,
                                                     model.cube_holder: batch_cube})
            summaries = train_step_results[-1]
            global_step = train_step_results[-2]
            train_writer.add_summary(summaries, global_step)
            # report intermediate training statistics
            if (global_step) % FLAGS.verbose_iters == 0:
                # Show training info
                print_current_training_stats(global_step, train_step_results, time.time() - t1)

            # run validation
            if (global_step) % FLAGS.validation_iters == 0:
                test_losses = []
                test_errors = []
                for _ in range(20):
                    batch_x, batch_y, batch_cube = test_data_generator.next()

                    eval_ops = get_eval_ops(model)
                    # test_batch_loss, test_batch_error, \
                    # summaries = sess.run([model.stage_loss[-1], model.stage_error[-1]
                    #                          , merged_summary],
                    #                      feed_dict={model.input_images: batch_x,
                    #                                 model.label_holder: batch_y,
                    #                                 model.cube_holder: batch_cube})
                    eval_ops.append(merged_summary)
                    test_batch_loss, test_batch_error, \
                    summaries = sess.run(eval_ops,
                                         feed_dict={model.input_images: batch_x,
                                                    model.label_holder: batch_y,
                                                    model.cube_holder: batch_cube})

                    test_losses.append(test_batch_loss)
                    test_errors.append(test_batch_error)

                test_mean_loss = np.mean(test_losses)
                test_mean_error = np.mean(test_errors)

                print('\n Validation loss:{}  Validation error:{}mm\n'.format(test_mean_loss, test_mean_error))
                test_writer.add_summary(summaries, global_step)

            # save a model checkpoint
            if (global_step) % FLAGS.model_save_iters == 0:
                saver.save(sess=sess, global_step=global_step,
                           save_path=os.path.join(model_save_dir, FLAGS.model_name))
                print('\nModel checkpoint saved...\n')

            if global_step == FLAGS.training_iters:
                break

        print('Training done.')
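
For reference, a minimal sketch of the Data_Generator helper assumed above (hypothetical; the real implementation lives in the repository):

class Data_Generator:
    def __init__(self, batch_size, images, labels, cubes):
        self.batch_size = batch_size
        self.images, self.labels, self.cubes = images, labels, cubes
        self.idx = 0

    def next(self):
        # Yield the next consecutive batch, wrapping around at the end of the data.
        if self.idx + self.batch_size > self.images.shape[0]:
            self.idx = 0
        i, j = self.idx, self.idx + self.batch_size
        self.idx = j
        return self.images[i:j], self.labels[i:j], self.cubes[i:j]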