Ejemplo n.º 1
0
def test_input_full(config, seqconfig):
    """Visually sanity-check the train and validation input pipelines.

    Builds the train (no augmentation) and validation (batch size 1) input
    queues and loops until the queues raise ``OutOfRangeError``. Each step it
    pulls one train batch and one validation sample. From step 1 onward it
    displays the validation image with its reshaped joint labels via
    ``show_image``.

    Args:
        config: experiment config; reads ``tfrecord_dir``,
            ``train_tfrecords``, ``test_tfrecords``, ``epochs``,
            ``image_orig_size``, ``num_classes`` and ``train_batch``.
        seqconfig: unused; kept so all entry points share one signature.
    """
    train_data = os.path.join(config.tfrecord_dir, config.train_tfrecords)
    val_data = os.path.join(config.tfrecord_dir, config.test_tfrecords)
    # Labels are flat vectors of num_joints * 3 coordinates. Floor division
    # keeps the reshape dimension an int on Python 3 as well as Python 2.
    num_joints = config.num_classes // 3
    with tf.device('/cpu:0'):
        # NOTE(review): the other callers in this file unpack four values
        # (images, labels, com3Ds, Ms) from ``inputs``; confirm this
        # two-value form still matches its signature.
        train_images, train_labels = inputs(
            tfrecord_file=train_data,
            num_epochs=config.epochs,
            image_target_size=config.image_orig_size,
            label_shape=config.num_classes,
            batch_size=config.train_batch,
            data_augment=False)
        val_images, val_labels = inputs(
            tfrecord_file=val_data,
            num_epochs=config.epochs,
            image_target_size=config.image_orig_size,
            label_shape=config.num_classes,
            batch_size=1)
        label_shaped = tf.reshape(
            train_labels, [config.train_batch, num_joints, 3])
        # Labels compared with themselves: a smoke test that getMeanError
        # runs on pipeline output.
        error = getMeanError(label_shaped, label_shaped)
        # Batch size 1 matches the validation pipeline above.
        val_label_shaped = tf.reshape(val_labels, [1, num_joints, 3])
    with tf.Session() as sess:
        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer()))
        step = 0
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                # Pull (and discard) a train batch so that pipeline is exercised too.
                sess.run([train_images, train_labels, error])
                print(step)
                val_image_np, _, val_image_reshaped_label = sess.run(
                    [val_images, val_labels, val_label_shaped])

                if step > 0:
                    val_im = val_image_np[0]
                    print("val_im shape:{}".format(val_im.shape))
                    val_jts = val_image_reshaped_label[0]
                    show_image(val_im, val_jts)
                step += 1

        except tf.errors.OutOfRangeError:
            print("Done. Epoch limit reached.")
        finally:
            coord.request_stop()
        coord.join(threads)
Ejemplo n.º 2
0
def test_model(config, seqconfig):
    """Evaluate the newest saved hierarchical model on the test TFRecords.

    Restores the latest checkpoint in ``config.model_output`` once, then runs
    the test set for one epoch with batch size 1. For every sample it:

    - prints the joint error;
    - saves a result image every 100 samples, and for samples whose error
      exceeds 40;
    - collects the rescaled label and prediction joint arrays.

    Afterwards it prints the mean error and pickles the collected arrays to
    ``results_com/hier/cnn_result_cache.pkl``. It then prints the mean error
    computed by ``getMeanError_np``.

    Args:
        config: experiment config; reads ``tfrecord_dir``,
            ``test_tfrecords``, ``image_target_size``, ``num_classes`` and
            ``model_output``.
        seqconfig: sequence config; ``seqconfig['cube'][2] / 2.`` rescales
            normalised joints into the units the log messages call "mm".

    Raises:
        ValueError: if ``config.model_output`` contains no checkpoint.
    """
    test_data = os.path.join(config.tfrecord_dir, config.test_tfrecords)
    # Floor division keeps the reshape dimension an int on Python 3 too.
    num_joints = config.num_classes // 3
    half_cube = seqconfig['cube'][2] / 2.
    with tf.device('/cpu:0'):
        images, labels, com3Ds, Ms = inputs(
            tfrecord_file=test_data,
            num_epochs=1,
            image_target_size=config.image_target_size,
            label_shape=config.num_classes,
            batch_size=1)
    with tf.device('/gpu:1'):
        with tf.variable_scope("cnn"):
            model = hier_model_struct()
            # Branch output sizes: 13 joints for each of the first four
            # branches and 12 for the last, 3 coordinates per joint.
            model.build(images,
                        config.num_classes,
                        13 * 3,
                        13 * 3,
                        13 * 3,
                        13 * 3,
                        12 * 3,
                        train_mode=False)
            labels_shaped = tf.reshape(labels, [num_joints, 3]) * \
                                seqconfig['cube'][2] / 2.
            results_shaped = tf.reshape(model.output, [num_joints, 3]) * \
                                 seqconfig['cube'][2] / 2.
            error = getMeanError(labels_shaped, results_shaped)

        gpuconfig = tf.ConfigProto()
        gpuconfig.gpu_options.allow_growth = True
        gpuconfig.allow_soft_placement = True
        saver = tf.train.Saver()

        with tf.Session(config=gpuconfig) as sess:
            sess.run(
                tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))
            # Restore once up front; the weights do not change during
            # evaluation, so re-restoring per sample only wasted time.
            checkpoints = tf.train.latest_checkpoint(config.model_output)
            if checkpoints is None:
                raise ValueError(
                    "no checkpoint found in {}".format(config.model_output))
            saver.restore(sess, checkpoints)

            step = 0
            sum_error = 0.0
            joint_labels = []
            joint_results = []
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                while not coord.should_stop():
                    (images_np, images_coms, images_Ms, joint_error,
                     labels_sp, results_sp) = sess.run(
                         [images, com3Ds, Ms, error, labels_shaped,
                          results_shaped])
                    joint_labels.append(labels_sp)
                    joint_results.append(results_sp)
                    print("step={}, test error={} mm".format(
                        step, joint_error))

                    sum_error += joint_error
                    if step % 100 == 0:
                        result_name = "results_com/hier/results/image_{}.png".format(
                            step)
                        save_result_image(images_np, images_coms, images_Ms,
                                          labels_sp, results_sp, half_cube,
                                          result_name)
                    if joint_error > 40:
                        result_name = "results_com/hier/bad/image_{}.png".format(
                            step)
                        save_result_image(images_np, images_coms, images_Ms,
                                          labels_sp, results_sp, half_cube,
                                          result_name)
                    step += 1
            except tf.errors.OutOfRangeError:
                print("Done.Epoch limit reached.")
            finally:
                coord.request_stop()
            coord.join(threads)
            print("load model from {}".format(checkpoints))
            if step == 0:
                # Nothing was evaluated: avoid dividing by zero and pickling
                # empty results.
                print("no test samples were evaluated")
                return
            print("testing mean error is {}mm".format(sum_error / step))

            pickleCache = 'results_com/hier/cnn_result_cache.pkl'
            print("Save cache data to {}".format(pickleCache))
            with open(pickleCache, 'wb') as f:
                cPickle.dump((joint_labels, joint_results),
                             f,
                             protocol=cPickle.HIGHEST_PROTOCOL)
            np_labels = np.asarray(joint_labels)
            np_results = np.asarray(joint_results)
            np_mean = getMeanError_np(np_labels, np_results)
            print(np_mean)
Ejemplo n.º 3
0
def train_model(config, seqconfig):
    """Train the hierarchical model and checkpoint the best validation result.

    Builds train and validation input pipelines and splits each flat label
    vector into five joint groups (P, R, M, I, T). It trains
    ``hier_model_struct`` with the sum of L2 losses over the full output and
    the five branch outputs. Every 200 steps it:

    - evaluates one validation batch;
    - writes TensorBoard summaries;
    - saves a checkpoint when the validation error is the best seen so far.
      The first validation always saves, so at least one checkpoint exists.

    Args:
        config: experiment config; reads ``tfrecord_dir``,
            ``train_tfrecords``, ``val_tfrecords``, ``epochs``,
            ``image_target_size``, ``num_classes``, ``train_batch``,
            ``val_batch``, ``wd_penalty``, ``train_summaries`` and
            ``model_output``.
        seqconfig: sequence config; ``seqconfig['cube'][2] / 2.`` rescales
            normalised errors into the units the log messages call "mm".
    """
    train_data = os.path.join(config.tfrecord_dir, config.train_tfrecords)
    val_data = os.path.join(config.tfrecord_dir, config.val_tfrecords)
    # Labels are flat vectors of num_joints * 3 coordinates. Floor division
    # keeps the reshape dimension an int on Python 3 as well as Python 2.
    num_joints = config.num_classes // 3
    # Joint indices feeding the five branch heads, in order P, R, M, I, T.
    # The first four take six joints of their own and T takes five; every
    # group also includes the shared joints 29-35. The indices therefore
    # assume 36 joints. Presumably one group per finger
    # NOTE(review): confirm against the dataset's joint layout.
    shared_joints = list(range(29, 36))
    joint_groups = [
        list(range(0, 6)) + shared_joints,    # P
        list(range(6, 12)) + shared_joints,   # R
        list(range(12, 18)) + shared_joints,  # M
        list(range(18, 24)) + shared_joints,  # I
        list(range(24, 29)) + shared_joints,  # T
    ]
    # Flat size (joints * 3) of each branch output.
    group_sizes = [len(idx) * 3 for idx in joint_groups]

    def _group_labels(flat_labels, batch_size):
        """Return ([B, J, 3] labels, per-group [B, n, 3], per-group [B, n*3])."""
        shaped = tf.reshape(flat_labels, [batch_size, num_joints, 3])
        joints = tf.split(shaped, num_joints, axis=1)
        groups_shaped = [
            tf.concat([joints[j] for j in idx], axis=1) for idx in joint_groups
        ]
        groups_flat = [
            tf.reshape(group, [batch_size, size])
            for group, size in zip(groups_shaped, group_sizes)
        ]
        return shaped, groups_shaped, groups_flat

    def _scaled_error(labels_shaped, results_shaped):
        """Mean joint error rescaled by half the cube size."""
        return getMeanError_train(
            labels_shaped, results_shaped) * seqconfig['cube'][2] / 2.

    def _branch_errors(outputs, groups_shaped, batch_size):
        """Scaled error of each branch output against its joint group."""
        errors = []
        for output, group in zip(outputs, groups_shaped):
            results = tf.reshape(
                output, [batch_size, group.get_shape().as_list()[1], 3])
            errors.append(_scaled_error(group, results))
        return errors

    with tf.device('/cpu:0'):
        train_images, train_labels, _, _ = inputs(
            tfrecord_file=train_data,
            num_epochs=config.epochs,
            image_target_size=config.image_target_size,
            label_shape=config.num_classes,
            batch_size=config.train_batch)
        val_images, val_labels, _, _ = inputs(
            tfrecord_file=val_data,
            num_epochs=config.epochs,
            image_target_size=config.image_target_size,
            label_shape=config.num_classes,
            batch_size=config.val_batch)
        label_shaped, train_groups_shaped, train_groups = _group_labels(
            train_labels, config.train_batch)
        val_label_shaped, val_groups_shaped, _ = _group_labels(
            val_labels, config.val_batch)
    with tf.device('/gpu:0'):
        with tf.variable_scope("cnn") as scope:
            print("create training graph:")
            model = hier_model_struct()
            model.build(train_images, config.num_classes, *group_sizes,
                        train_mode=True)
            train_outputs = [model.p_output, model.r_output, model.m_output,
                             model.i_output, model.t_output]
            hand_loss = tf.nn.l2_loss(model.output - train_labels)
            p_loss, r_loss, m_loss, i_loss, t_loss = [
                tf.nn.l2_loss(output - target)
                for output, target in zip(train_outputs, train_groups)
            ]
            loss = hand_loss + p_loss + r_loss + m_loss + i_loss + t_loss
            if config.wd_penalty is None:
                loss_wd = None
                train_op = tf.train.AdamOptimizer(1e-5).minimize(loss)
            else:
                # Weight decay on every trainable variable except biases.
                wd_l = [
                    v for v in tf.get_collection(
                        tf.GraphKeys.TRAINABLE_VARIABLES)
                    if 'biases' not in v.name
                ]
                loss_wd = loss + (config.wd_penalty *
                                  tf.add_n([tf.nn.l2_loss(x) for x in wd_l]))
                train_op = tf.train.AdamOptimizer(1e-5).minimize(loss_wd)

            train_results_shaped = tf.reshape(
                model.output, [config.train_batch, num_joints, 3])
            train_error = _scaled_error(label_shaped, train_results_shaped)
            train_group_errors = _branch_errors(
                train_outputs, train_groups_shaped, config.train_batch)

            print("using validation")
            scope.reuse_variables()
            val_model = hier_model_struct()
            val_model.build(val_images, config.num_classes, *group_sizes,
                            train_mode=False)
            val_outputs = [val_model.p_output, val_model.r_output,
                           val_model.m_output, val_model.i_output,
                           val_model.t_output]
            val_results_shaped = tf.reshape(
                val_model.output, [config.val_batch, num_joints, 3])
            val_error = _scaled_error(val_label_shaped, val_results_shaped)
            val_group_errors = _branch_errors(
                val_outputs, val_groups_shaped, config.val_batch)

            for name, tensor in [("loss", loss), ("p_loss", p_loss),
                                 ("r_loss", r_loss), ("m_loss", m_loss),
                                 ("i_loss", i_loss), ("t_loss", t_loss)]:
                tf.summary.scalar(name, tensor)
            if loss_wd is not None:
                tf.summary.scalar("loss_wd", loss_wd)
            tf.summary.scalar("train error", train_error)
            for letter, tensor in zip("PRMIT", train_group_errors):
                tf.summary.scalar("train {} error".format(letter), tensor)
            tf.summary.scalar("validation error", val_error)
            for letter, tensor in zip("PRMIT", val_group_errors):
                tf.summary.scalar("validation {} error".format(letter), tensor)

            summary_op = tf.summary.merge_all()
        saver = tf.train.Saver(tf.global_variables())

    # Without a weight-decay penalty the regularised loss is the plain loss;
    # fetch it as "losswd" instead of referencing an undefined tensor.
    loss_wd_fetch = loss if loss_wd is None else loss_wd

    # Initialize the graph
    gpuconfig = tf.ConfigProto()
    gpuconfig.gpu_options.allow_growth = True
    gpuconfig.allow_soft_placement = True
    val_min = None  # best validation error so far; None until first check
    with tf.Session(config=gpuconfig) as sess:
        summary_writer = tf.summary.FileWriter(config.train_summaries,
                                               sess.graph)
        sess.run(
            tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer()))
        step = 0
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                (_, tr_loss, tr_loss_wd, tr_error, tr_P_error, tr_R_error,
                 tr_M_error, tr_I_error, tr_T_error, tr_P_loss, tr_R_loss,
                 tr_M_loss, tr_I_loss, tr_T_loss) = sess.run(
                     [train_op, loss, loss_wd_fetch, train_error] +
                     train_group_errors +
                     [p_loss, r_loss, m_loss, i_loss, t_loss])
                print("step={},loss={},losswd={},ploss={},rloss={},mloss={},iloss={},tloss={},error={} mm,perror={} mm,rerror={} mm,merror={} mm,ierror={} mm,terror={} mm"\
                      .format(step,tr_loss,tr_loss_wd,tr_P_loss,tr_R_loss,tr_M_loss,tr_I_loss,tr_T_loss,tr_error,tr_P_error,tr_R_error,
                              tr_M_error,tr_I_error,tr_T_error))

                if step % 200 == 0:
                    # Fetch validation errors and summaries in one run so the
                    # logged validation scalars come from the same batch as
                    # the printed ones. A separate summary run would dequeue
                    # another validation batch.
                    (v_error, v_P_error, v_R_error, v_M_error, v_I_error,
                     v_T_error, summary_str) = sess.run(
                         [val_error] + val_group_errors + [summary_op])
                    print("     val_error={} mm, val_P_error={} mm, val_R_error={} mm, val_M_error={} mm, val_I_error={} mm, val_T_error={} mm"\
                          .format(v_error,v_P_error,v_R_error,v_M_error,v_I_error,v_T_error))
                    summary_writer.add_summary(summary_str, step)
                    # Save whenever this is the best validation error yet,
                    # including the very first check, so a checkpoint always
                    # exists for evaluation.
                    if val_min is None or v_error < val_min:
                        saver.save(sess,
                                   os.path.join(
                                       config.model_output,
                                       'hier_model' + str(step) + '.ckpt'),
                                   global_step=step)
                        val_min = v_error

                step += 1

        except tf.errors.OutOfRangeError:
            print("Done. Epoch limit reached.")
        finally:
            coord.request_stop()
            summary_writer.close()
        coord.join(threads)
Ejemplo n.º 4
0
def test_input_tree(config, seqconfig):
    """Visually check the joint-group label split on train and validation data.

    Builds the train and validation input queues and splits each flat label
    vector into five joint groups (P, R, M, I, T). It then loops until the
    queues raise ``OutOfRangeError``. At step 1 it draws every train image in
    the batch, plus one validation image, with ``check_image_label``: first
    with all joints, then with each joint group.

    Args:
        config: experiment config; reads ``tfrecord_dir``,
            ``train_tfrecords``, ``val_tfrecords``, ``epochs``,
            ``image_target_size``, ``num_classes`` and ``train_batch``.
        seqconfig: sequence config; ``seqconfig['cube'][2] / 2.`` is passed
            to ``check_image_label``.
    """
    train_data = os.path.join(config.tfrecord_dir, config.train_tfrecords)
    val_data = os.path.join(config.tfrecord_dir, config.val_tfrecords)
    # Labels are flat vectors of num_joints * 3 coordinates. Floor division
    # keeps the reshape dimension an int on Python 3 as well as Python 2.
    num_joints = config.num_classes // 3
    # The validation pipeline is built with batch size 1, so every validation
    # reshape must use 1 as well. Using config.val_batch failed whenever
    # val_batch != 1.
    val_batch = 1
    # Joint indices per group, in order P, R, M, I, T. All groups include the
    # shared joints 29-35, so the indices assume 36 joints.
    shared_joints = list(range(29, 36))
    joint_groups = [
        list(range(0, 6)) + shared_joints,    # P
        list(range(6, 12)) + shared_joints,   # R
        list(range(12, 18)) + shared_joints,  # M
        list(range(18, 24)) + shared_joints,  # I
        list(range(24, 29)) + shared_joints,  # T
    ]

    def _group_labels(flat_labels, batch_size):
        """Return ([B, J, 3] labels, list of per-group [B, n*3] labels)."""
        shaped = tf.reshape(flat_labels, [batch_size, num_joints, 3])
        joints = tf.split(shaped, num_joints, axis=1)
        groups_shaped = [
            tf.concat([joints[j] for j in idx], axis=1) for idx in joint_groups
        ]
        groups_flat = [
            tf.reshape(group, [batch_size, len(idx) * 3])
            for group, idx in zip(groups_shaped, joint_groups)
        ]
        return shaped, groups_flat

    with tf.device('/cpu:0'):
        train_images, train_labels, com3Ds, Ms = inputs(
            tfrecord_file=train_data,
            num_epochs=config.epochs,
            image_target_size=config.image_target_size,
            label_shape=config.num_classes,
            batch_size=config.train_batch)
        val_images, val_labels, val_com3Ds, val_Ms = inputs(
            tfrecord_file=val_data,
            num_epochs=config.epochs,
            image_target_size=config.image_target_size,
            label_shape=config.num_classes,
            batch_size=val_batch)
        label_shaped, train_groups = _group_labels(
            train_labels, config.train_batch)
        # Labels compared with themselves: a smoke test of getMeanError.
        error = getMeanError(label_shaped, label_shaped)
        _, val_groups = _group_labels(val_labels, val_batch)
    half_cube = seqconfig['cube'][2] / 2.
    with tf.Session() as sess:
        sess.run(tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer()))
        step = 0
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                fetched = sess.run(
                    [train_images, train_labels, com3Ds, Ms, error] +
                    train_groups)
                image_np, image_label, image_coms, image_Ms, train_error = \
                    fetched[:5]
                groups_np = fetched[5:]
                print(step)
                print(image_np.shape)
                print(train_error)

                val_fetched = sess.run(
                    [val_images, val_labels, val_com3Ds, val_Ms] + val_groups)
                val_image_np, val_image_label, val_image_coms, val_image_Ms = \
                    val_fetched[:4]
                val_groups_np = val_fetched[4:]

                if step == 1:
                    for b in range(config.train_batch):
                        im = image_np[b]
                        print("im shape:{}".format(im.shape))
                        image_com = image_coms[b]
                        image_M = image_Ms[b]
                        im = im.reshape([128, 128])
                        check_image_label(im, image_label[b], image_com,
                                          image_M, half_cube,
                                          allJoints=True, line=True)
                        for group_np in groups_np:
                            check_image_label(im, group_np[b], image_com,
                                              image_M, half_cube, line=False)

                    val_im = val_image_np[0]
                    print("val_im shape:{}".format(val_im.shape))
                    val_image_com = val_image_coms[0]
                    val_image_M = val_image_Ms[0]
                    val_im = val_im.reshape([128, 128])
                    check_image_label(val_im, val_image_label[0],
                                      val_image_com, val_image_M, half_cube,
                                      allJoints=True, line=True)
                    for val_group_np in val_groups_np:
                        check_image_label(val_im, val_group_np[0],
                                          val_image_com, val_image_M,
                                          half_cube, line=False)
                step += 1

        except tf.errors.OutOfRangeError:
            print("Done. Epoch limit reached.")
        finally:
            coord.request_stop()
        coord.join(threads)