def main(flags):

    current_time = time.strftime("%m/%d/%H/%M/%S")
    train_logdir = os.path.join(flags.logdir, "train", current_time)
    validation_logdir = os.path.join(flags.logdir, "validation",
                                     current_time)

    train = pd.read_csv(flags.training_dir)
    num_train = train.shape[0]

    validation = pd.read_csv(flags.validation_dir)
    num_validation = validation.shape[0]

    tf.reset_default_graph()
    # Placeholders for the input image and the ground-truth label map
    X = tf.placeholder(
        tf.float32,
        shape=[flags.batch_size, flags.h, flags.w, flags.c_image],
        name='X')
    y = tf.placeholder(
        tf.float32,
        shape=[flags.batch_size, flags.h, flags.w, flags.c_label],
        name='y')
    # Placeholder that switches the network between training and inference behaviour
    training = tf.placeholder(tf.bool, name='training')
    # Side outputs and fused output of the network; pass the `training`
    # placeholder (fed below) rather than a constant so the network can be
    # switched to inference behaviour during validation.
    score_dsn5_up, score_dsn4_up, score_dsn3_up, score_dsn2_up, score_dsn1_up, upscore_fuse = model.unet(
        X, flags.batch_size, flags.h, flags.w, training=training)
    print(upscore_fuse.get_shape().as_list())

    # Cross-entropy loss on each side output and on the fused output
    loss5 = loss_CE(score_dsn5_up, y)
    loss4 = loss_CE(score_dsn4_up, y)
    loss3 = loss_CE(score_dsn3_up, y)
    loss2 = loss_CE(score_dsn2_up, y)
    loss1 = loss_CE(score_dsn1_up, y)
    loss_fuse = loss_CE(upscore_fuse, y)
    # Log each loss term as a TensorBoard scalar
    tf.summary.scalar("CE5", loss5)
    tf.summary.scalar("CE4", loss4)
    tf.summary.scalar("CE3", loss3)
    tf.summary.scalar("CE2", loss2)
    tf.summary.scalar("CE1", loss1)
    tf.summary.scalar("CE_fuse", loss_fuse)

    global_step = tf.Variable(0,
                              dtype=tf.int64,
                              trainable=False,
                              name='global_step')

    # Decay factor for the AAF loss: dec = 10^(-t / T), where t is the global
    # step and T is the total number of training steps, so it anneals from
    # 1.0 towards 0.1 over the course of training.
    dec = tf.pow(10.0, (tf.cast(
        -(global_step / int(num_train / flags.batch_size * flags.epochs)),
        tf.float32)))

    # Learnable per-scale edge / non-edge weights for the adaptive affinity
    # loss (one weight per kernel size), normalised below with a softmax.
    w_edge = tf.get_variable(name='edge_w',
                             shape=(1, 1, 1, 2, 1, 3),
                             dtype=tf.float32,
                             initializer=tf.constant_initializer(0))
    w_edge = tf.nn.softmax(w_edge, dim=-1)
    w_not_edge = tf.get_variable(name='nonedge_w',
                                 shape=(1, 1, 1, 2, 1, 3),
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0))
    w_not_edge = tf.nn.softmax(w_not_edge, dim=-1)

    score_dsn5_up = tf.nn.sigmoid(score_dsn5_up)
    score_dsn4_up = tf.nn.sigmoid(score_dsn4_up)
    score_dsn3_up = tf.nn.sigmoid(score_dsn3_up)
    score_dsn2_up = tf.nn.sigmoid(score_dsn2_up)
    score_dsn1_up = tf.nn.sigmoid(score_dsn1_up)
    upscore_fuse = tf.nn.sigmoid(upscore_fuse, name='output')

    # Stack background (1 - p) and foreground (p) probabilities into a
    # two-channel per-pixel distribution for the two-class affinity loss below.
    upscore_fuse_0 = 1 - upscore_fuse
    prob = tf.concat([upscore_fuse_0, upscore_fuse], axis=-1)

    labels = tf.cast(y, tf.uint8)
    one_hot_lab = tf.one_hot(tf.squeeze(labels, axis=-1), depth=2)
    aaf_losses = []
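    # Apply AAF on 3x3 patch.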
    eloss_1, neloss_1 = lossx.adaptive_affinity_loss(labels, one_hot_lab, prob,
                                                     1, 2, 3, w_edge[..., 0],
                                                     w_not_edge[..., 0])
    # Apply AAF on 5x5 patch.
    eloss_2, neloss_2 = lossx.adaptive_affinity_loss(labels, one_hot_lab, prob,
                                                     2, 2, 3, w_edge[..., 1],
                                                     w_not_edge[..., 1])
    # Apply AAF on 7x7 patch.
    eloss_3, neloss_3 = lossx.adaptive_affinity_loss(labels, one_hot_lab, prob,
                                                     3, 2, 3, w_edge[..., 2],
                                                     w_not_edge[..., 2])
    # Sum the edge / non-edge terms and scale them by `dec` so the AAF loss
    # fades out as training progresses.
    aaf_loss = tf.reduce_mean(eloss_1) * dec
    aaf_loss += tf.reduce_mean(eloss_2) * dec
    aaf_loss += tf.reduce_mean(eloss_3) * dec
    aaf_loss += tf.reduce_mean(neloss_1) * dec
    aaf_loss += tf.reduce_mean(neloss_2) * dec
    aaf_loss += tf.reduce_mean(neloss_3) * dec
    aaf_losses.append(aaf_loss)

    # Sum all loss terms.
    mean_seg_loss = loss5 + loss4 + loss3 + loss2 + loss1 + loss_fuse
    mean_aaf_loss = tf.add_n(aaf_losses)
    CE_total = mean_seg_loss + mean_aaf_loss

    tf.summary.scalar("CE_total", CE_total)
    tf.summary.scalar("mean_seg_loss", mean_seg_loss)
    tf.summary.scalar("mean_aaf_loss", mean_aaf_loss)
    tf.summary.scalar("dec", dec)

    # Split the trainable variables into three groups: backbone ('block')
    # layers, the remaining layers, and the AAF edge / non-edge weights.
    all_trainable = tf.trainable_variables()
    fc_trainable = [
        v for v in all_trainable
        if 'block' not in v.name and 'edge' not in v.name
    ]  # lr*1
    base_trainable = [v for v in all_trainable if 'block' in v.name]  # lr*10
    aaf_trainable = [v for v in all_trainable if 'edge' in v.name]

    # Computes gradients per iteration.
    grads = tf.gradients(CE_total,
                         base_trainable + fc_trainable + aaf_trainable)
    grads_base = grads[0:len(base_trainable)]
    grads_fc = grads[len(base_trainable):len(base_trainable) +
                     len(fc_trainable)]
    grads_aaf = grads[len(base_trainable) + len(fc_trainable):]
    grads_aaf = [-g for g in grads_aaf]
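    # Negating these gradients turns the Adam update below into gradient
    # ascent on the affinity terms, i.e. the edge / non-edge weights are
    # trained adversarially against the segmentation network (as in the
    # adaptive affinity fields formulation).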

    learning_rate = tf.train.exponential_decay(flags.learning_rate,
                                               global_step,
                                               decay_steps=flags.decay_step,
                                               decay_rate=flags.decay_rate,
                                               staircase=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        opt_base = tf.train.AdamOptimizer(10 * learning_rate)
        opt_fc = tf.train.AdamOptimizer(learning_rate)
        opt_aaf = tf.train.AdamOptimizer(learning_rate)
        # Note: reuse the global_step variable created above. Creating a second
        # step variable here would leave the learning-rate schedule and the
        # AAF decay factor frozen at their initial values.

        # Define tensorflow operations which apply gradients to update variables.
        train_op_base = opt_base.apply_gradients(
            zip(grads_base, base_trainable))
        train_op_fc = opt_fc.apply_gradients(zip(grads_fc, fc_trainable))
        train_op_aaf = opt_aaf.apply_gradients(zip(grads_aaf, aaf_trainable),
                                               global_step=global_step)
        train_op = tf.group(train_op_base, train_op_fc, train_op_aaf)

    train_csv = tf.train.string_input_producer(['train.csv'])
    validation_csv = tf.train.string_input_producer(['validation.csv'])
    #get the training and validation data
    train_image, train_label = read_csv(train_csv, augmentation=True)
    validation_image, validation_label = read_csv(validation_csv,
                                                  augmentation=False)

    X_train_batch_op, y_train_batch_op = tf.train.shuffle_batch(
        [train_image, train_label],
        batch_size=flags.batch_size,
        capacity=flags.batch_size * 500,
        min_after_dequeue=flags.batch_size * 100,
        allow_smaller_final_batch=True)

    X_validation_batch_op, y_validation_batch_op = tf.train.batch(
        [validation_image, validation_label],
        batch_size=flags.batch_size,
        capacity=flags.batch_size * 20,
        allow_smaller_final_batch=True)

    print('Shuffle batch done')
    # Add inputs and outputs to collections so the graph can be reloaded for inference
    tf.add_to_collection('inputs', X)
    tf.add_to_collection('inputs', training)
    tf.add_to_collection('score_dsn5_up', score_dsn5_up)
    tf.add_to_collection('score_dsn4_up', score_dsn4_up)
    tf.add_to_collection('score_dsn3_up', score_dsn3_up)
    tf.add_to_collection('score_dsn2_up', score_dsn2_up)
    tf.add_to_collection('score_dsn1_up', score_dsn1_up)
    tf.add_to_collection('upscore_fuse', upscore_fuse)

    # Log inputs, labels and outputs as TensorBoard images
    tf.summary.image('Input Image:', X)
    tf.summary.image('Label:', y)
    tf.summary.image('score_dsn5_up:', score_dsn5_up)
    tf.summary.image('score_dsn4_up:', score_dsn4_up)
    tf.summary.image('score_dsn3_up:', score_dsn3_up)
    tf.summary.image('score_dsn2_up:', score_dsn2_up)
    tf.summary.image('score_dsn1_up:', score_dsn1_up)
    tf.summary.image('upscore_fuse:', upscore_fuse)

    # Log the learning rate as a TensorBoard scalar
    tf.summary.scalar("learning_rate", learning_rate)

    # Log the value distribution of each output as a TensorBoard histogram
    tf.summary.histogram('score_dsn1_up:', score_dsn1_up)
    tf.summary.histogram('score_dsn2_up:', score_dsn2_up)
    tf.summary.histogram('score_dsn3_up:', score_dsn3_up)
    tf.summary.histogram('score_dsn4_up:', score_dsn4_up)
    tf.summary.histogram('score_dsn5_up:', score_dsn5_up)
    tf.summary.histogram('upscore_fuse:', upscore_fuse)
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(train_logdir, sess.graph)
        validation_writer = tf.summary.FileWriter(validation_logdir)
        init = tf.global_variables_initializer()
        sess.run(init)
        saver = tf.train.Saver()
        try:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            for epoch in range(flags.epochs):
                # Training: run the optimizer and log summaries for each batch
                for step in range(0, num_train, flags.batch_size):
                    X_train, y_train = sess.run(
                        [X_train_batch_op, y_train_batch_op])
                    _, step_ce, step_summary, global_step_value = sess.run(
                        [train_op, CE_total, summary_op, global_step],
                        feed_dict={
                            X: X_train,
                            y: y_train,
                            training: True
                        })

                    train_writer.add_summary(step_summary, global_step_value)
                    print('epoch:{} step:{} loss_CE:{}'.format(
                        epoch + 1, global_step_value, step_ce))
                # Validation: forward passes only, no parameter updates
                for step in range(0, num_validation, flags.batch_size):
                    X_test, y_test = sess.run(
                        [X_validation_batch_op, y_validation_batch_op])
                    step_ce, step_summary = sess.run([CE_total, summary_op],
                                                     feed_dict={
                                                         X: X_test,
                                                         y: y_test,
                                                         training: False
                                                     })

                    # Rescale the validation step index onto the training-step
                    # axis so both curves line up in TensorBoard.
                    validation_writer.add_summary(
                        step_summary,
                        epoch * (num_train // flags.batch_size) +
                        step // flags.batch_size * num_train // num_validation)
                    print('validation loss_CE:{}'.format(step_ce))
                saver.save(sess, '{}/model.ckpt'.format(flags.model_dir))
        finally:
            coord.request_stop()
            coord.join(threads)
            saver.save(sess, "{}/model.ckpt".format(flags.model_dir))
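
# The loss_CE helper used above is defined elsewhere in the repository. A
# minimal sketch of what it is assumed to compute (mean sigmoid cross-entropy
# between a logit side output and the binary label map); the real helper may
# differ in its weighting or reduction:
def loss_CE(logits, labels):
    # Pixel-wise sigmoid cross-entropy, averaged over the whole batch.
    ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(ce)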
Example #2
def main(flags):
    current_time = time.strftime("%m/%d/%H/%M/%S")
    train_logdir = os.path.join(flags.logdir, "image", current_time)
    test_logdir = os.path.join(flags.logdir, "test", current_time)

    train = pd.read_csv(flags.data_dir)
    num_train = train.shape[0]

    test = pd.read_csv(flags.test_dir)
    num_test = test.shape[0]

    tf.reset_default_graph()
    X = tf.placeholder(tf.float32, shape=[None, h, w, c_image], name='X')
    y = tf.placeholder(tf.float32, shape=[None, h, w, c_label], name='y')
    mode = tf.placeholder(tf.bool, name='mode')

    pred = model.unet(X, mode)
    if flags.is_cross_entropy:
        CE_op = loss_CE(pred, y)
        loss = CE_op
        tf.summary.scalar("CE", CE_op)
    else:
        IOU_op = loss_IOU(pred, y)
        loss = -IOU_op
        tf.summary.scalar('IOU:', IOU_op)

    global_step = tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    learning_rate = tf.train.exponential_decay(flags.learning_rate, global_step,
                                               tf.cast(num_train / flags.batch_size * flags.decay_step, tf.int32),
                                               flags.decay_rate, staircase=True)

    with tf.control_dependencies(update_ops):
        training_op = train_op(loss, learning_rate)


    train_csv = tf.train.string_input_producer(['data_image.csv'])
    test_csv = tf.train.string_input_producer(['data_test.csv'])

    train_image, train_label = read_csv(train_csv, augmentation=True)
    test_image, test_label = read_csv(test_csv, augmentation=False)

    # batch_size is the number of samples returned per batch; capacity is the
    # size of the prefetch queue.
    X_train_batch_op, y_train_batch_op = tf.train.shuffle_batch(
        [train_image, train_label],
        batch_size=flags.batch_size,
        capacity=flags.batch_size * 5,
        min_after_dequeue=flags.batch_size * 2,
        allow_smaller_final_batch=True)

    X_test_batch_op, y_test_batch_op = tf.train.batch(
        [test_image, test_label],
        batch_size=flags.batch_size,
        capacity=flags.batch_size * 2,
        allow_smaller_final_batch=True)



    print('Shuffle batch done')
    #tf.summary.scalar('loss/Cross_entropy', CE_op)

    tf.add_to_collection('inputs', X)
    tf.add_to_collection('inputs', mode)
    tf.add_to_collection('pred', pred)

    tf.summary.image('Input Image:', X)
    tf.summary.image('Label:', y)
    tf.summary.image('Predicted Image:', pred)

    tf.summary.scalar("learning_rate", learning_rate)

    # Histogram summary: records the distribution of the predicted values.
    tf.summary.histogram('Predicted Image:', pred)


    # Merge all summary ops into one so they do not have to be run individually.
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(train_logdir, sess.graph)
        test_writer = tf.summary.FileWriter(test_logdir)

        init = tf.global_variables_initializer()
        sess.run(init)


        saver = tf.train.Saver()
        if os.path.exists(flags.model_dir) and tf.train.checkpoint_exists(flags.model_dir):
            latest_check_point = tf.train.latest_checkpoint(flags.model_dir)
            saver.restore(sess, latest_check_point)

        else:
            print('No model')
            # Make sure the checkpoint directory exists before the first save.
            os.makedirs(flags.model_dir, exist_ok=True)

        try:
            # tf.train.string_input_producer registers QueueRunners on the
            # default graph; they must be started with
            # tf.train.start_queue_runners, otherwise the first sess.run on a
            # batch op blocks forever. The Coordinator is used to stop and
            # join those threads cleanly.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            for epoch in range(flags.epochs):
                for step in range(0, num_train, flags.batch_size):
                    X_train, y_train = sess.run(
                        [X_train_batch_op, y_train_batch_op])

                    if flags.is_cross_entropy:
                        _, step_ce, step_summary, global_step_value = sess.run(
                            [training_op, CE_op, summary_op, global_step],
                            feed_dict={X: X_train, y: y_train, mode: True})

                        train_writer.add_summary(step_summary,
                                                 global_step_value)
                        print('epoch:{} step:{} loss_CE:{}'.format(
                            epoch + 1, global_step_value, step_ce))
                    else:
                        _, step_iou, step_summary, global_step_value = sess.run(
                            [training_op, IOU_op, summary_op, global_step],
                            feed_dict={X: X_train, y: y_train, mode: True})

                        train_writer.add_summary(step_summary,
                                                 global_step_value)
                        print('epoch:{} step:{} loss_IOU:{}'.format(
                            epoch + 1, global_step_value, step_iou))

                for step in range(0, num_test, flags.batch_size):
                    X_test, y_test = sess.run(
                        [X_test_batch_op, y_test_batch_op])
                    if flags.is_cross_entropy:
                        step_ce, step_summary = sess.run(
                            [CE_op, summary_op],
                            feed_dict={X: X_test, y: y_test, mode: False})

                        test_writer.add_summary(
                            step_summary,
                            epoch * (num_train // flags.batch_size) +
                            step // flags.batch_size * num_train // num_test)
                        print('Test loss_CE:{}'.format(step_ce))
                    else:
                        step_iou, step_summary = sess.run(
                            [IOU_op, summary_op],
                            feed_dict={X: X_test, y: y_test, mode: False})

                        test_writer.add_summary(
                            step_summary,
                            epoch * (num_train // flags.batch_size) +
                            step // flags.batch_size * num_train // num_test)
                        print('Test loss_IOU:{}'.format(step_iou))

            saver.save(sess, '{}/model.ckpt'.format(flags.model_dir))

        finally:
            coord.request_stop()
            coord.join(threads)
            saver.save(sess, "{}/model.ckpt".format(flags.model_dir))
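
# The loss_IOU helper used above is defined elsewhere. A minimal sketch of the
# assumed behaviour (a soft IoU score computed from predicted probabilities,
# which the code above negates to obtain a loss); the real helper may differ:
def loss_IOU(pred, labels):
    # Soft intersection-over-union, averaged over the batch.
    inter = tf.reduce_sum(pred * labels, axis=[1, 2, 3])
    union = tf.reduce_sum(pred + labels, axis=[1, 2, 3]) - inter
    return tf.reduce_mean((inter + 1e-7) / (union + 1e-7))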
Example #3
def main(flags):
    current_time = time.strftime("%m/%d/%H/%M/%S")
    train_logdir = os.path.join(flags.logdir, "pig", current_time)
    test_logdir = os.path.join(flags.logdir, "test", current_time)

    train = pd.read_csv(flags.data_dir)

    num_train = train.shape[0]

    test = pd.read_csv(flags.test_dir)
    num_test = test.shape[0]

    tf.reset_default_graph()
    X = tf.placeholder(tf.float32,
                       shape=[flags.batch_size, h, w, c_image],
                       name='X')
    y = tf.placeholder(tf.float32,
                       shape=[flags.batch_size, h, w, c_label],
                       name='y')
    mode = tf.placeholder(tf.bool, name='mode')

    score_dsn6_up, score_dsn5_up, score_dsn4_up, score_dsn3_up, score_dsn2_up, score_dsn1_up, upscore_fuse = model.unet(
        X, mode)

    #print(score_dsn6_up.get_shape().as_list())

    loss6 = loss_CE(score_dsn6_up, y)
    loss5 = loss_CE(score_dsn5_up, y)
    loss4 = loss_CE(score_dsn4_up, y)
    loss3 = loss_CE(score_dsn3_up, y)
    loss2 = loss_CE(score_dsn2_up, y)
    loss1 = loss_CE(score_dsn1_up, y)
    loss_fuse = loss_CE(upscore_fuse, y)
    tf.summary.scalar("CE6", loss6)
    tf.summary.scalar("CE5", loss5)
    tf.summary.scalar("CE4", loss4)
    tf.summary.scalar("CE3", loss3)
    tf.summary.scalar("CE2", loss2)
    tf.summary.scalar("CE1", loss1)
    tf.summary.scalar("CE_fuse", loss_fuse)

    Loss = loss6 + loss5 + loss4 + loss3 + loss2 + 2 * loss1 + loss_fuse
    tf.summary.scalar("CE_total", Loss)

    global_step = tf.Variable(0,
                              dtype=tf.int64,
                              trainable=False,
                              name='global_step')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    learning_rate = tf.train.exponential_decay(flags.learning_rate,
                                               global_step,
                                               decay_steps=flags.decay_step,
                                               decay_rate=flags.decay_rate,
                                               staircase=True)

    with tf.control_dependencies(update_ops):
        training_op = train_op(Loss, learning_rate)

    train_csv = tf.train.string_input_producer(['pig1.csv'])
    test_csv = tf.train.string_input_producer(['pigtest1.csv'])

    train_image, train_label = read_csv(train_csv, augmentation=True)
    test_image, test_label = read_csv(test_csv, augmentation=False)

    # batch_size is the number of samples returned per batch; capacity is the
    # size of the prefetch queue.
    X_train_batch_op, y_train_batch_op = tf.train.shuffle_batch(
        [train_image, train_label],
        batch_size=flags.batch_size,
        capacity=flags.batch_size * 5,
        min_after_dequeue=flags.batch_size * 2,
        allow_smaller_final_batch=True)

    X_test_batch_op, y_test_batch_op = tf.train.batch(
        [test_image, test_label],
        batch_size=flags.batch_size,
        capacity=flags.batch_size * 2,
        allow_smaller_final_batch=True)

    print('Shuffle batch done')
    #tf.summary.scalar('loss/Cross_entropy', CE_op)
    score_dsn6_up = tf.nn.sigmoid(score_dsn6_up)
    score_dsn5_up = tf.nn.sigmoid(score_dsn5_up)
    score_dsn4_up = tf.nn.sigmoid(score_dsn4_up)
    score_dsn3_up = tf.nn.sigmoid(score_dsn3_up)
    score_dsn2_up = tf.nn.sigmoid(score_dsn2_up)
    score_dsn1_up = tf.nn.sigmoid(score_dsn1_up)
    upscore_fuse = tf.nn.sigmoid(upscore_fuse)
    print(upscore_fuse.get_shape().as_list())

    tf.add_to_collection('inputs', X)
    tf.add_to_collection('inputs', mode)
    tf.add_to_collection('score_dsn6_up', score_dsn6_up)
    tf.add_to_collection('score_dsn5_up', score_dsn5_up)
    tf.add_to_collection('score_dsn4_up', score_dsn4_up)
    tf.add_to_collection('score_dsn3_up', score_dsn3_up)
    tf.add_to_collection('score_dsn2_up', score_dsn2_up)
    tf.add_to_collection('score_dsn1_up', score_dsn1_up)
    tf.add_to_collection('upscore_fuse', upscore_fuse)

    tf.summary.image('Input Image:', X)
    tf.summary.image('Label:', y)
    tf.summary.image('score_dsn6_up:', score_dsn6_up)
    tf.summary.image('score_dsn5_up:', score_dsn5_up)
    tf.summary.image('score_dsn4_up:', score_dsn4_up)
    tf.summary.image('score_dsn3_up:', score_dsn3_up)
    tf.summary.image('score_dsn2_up:', score_dsn2_up)
    tf.summary.image('score_dsn1_up:', score_dsn1_up)
    tf.summary.image('upscore_fuse:', upscore_fuse)

    tf.summary.scalar("learning_rate", learning_rate)

    # Histogram summaries: record the value distribution of each output.
    tf.summary.histogram('score_dsn1_up:', score_dsn1_up)
    tf.summary.histogram('score_dsn2_up:', score_dsn2_up)
    tf.summary.histogram('score_dsn3_up:', score_dsn3_up)
    tf.summary.histogram('score_dsn4_up:', score_dsn4_up)
    tf.summary.histogram('score_dsn5_up:', score_dsn5_up)
    tf.summary.histogram('score_dsn6_up:', score_dsn6_up)
    tf.summary.histogram('upscore_fuse:', upscore_fuse)

    # Merge all summary ops into one so they do not have to be run individually.
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(train_logdir, sess.graph)
        test_writer = tf.summary.FileWriter(test_logdir)

        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver()
        if os.path.exists(flags.model_dir) and tf.train.checkpoint_exists(
                flags.model_dir):
            latest_check_point = tf.train.latest_checkpoint(flags.model_dir)
            saver.restore(sess, latest_check_point)
        else:
            print('No model')
            # Make sure the checkpoint directory exists before the first save.
            os.makedirs(flags.model_dir, exist_ok=True)

        try:
            # tf.train.string_input_producer registers QueueRunners on the
            # default graph; they must be started with
            # tf.train.start_queue_runners, otherwise the first sess.run on a
            # batch op blocks forever. The Coordinator is used to stop and
            # join those threads cleanly.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            for epoch in range(flags.epochs):
                for step in range(0, num_train, flags.batch_size):
                    X_train, y_train = sess.run(
                        [X_train_batch_op, y_train_batch_op])
                    _, step_ce, step_summary, global_step_value = sess.run(
                        [training_op, Loss, summary_op, global_step],
                        feed_dict={
                            X: X_train,
                            y: y_train,
                            mode: True
                        })

                    train_writer.add_summary(step_summary, global_step_value)
                    print('epoch:{} step:{} loss_CE:{}'.format(
                        epoch + 1, global_step_value, step_ce))
                for step in range(0, num_test, flags.batch_size):
                    X_test, y_test = sess.run(
                        [X_test_batch_op, y_test_batch_op])
                    step_ce, step_summary = sess.run([Loss, summary_op],
                                                     feed_dict={
                                                         X: X_test,
                                                         y: y_test,
                                                         mode: False
                                                     })

                    test_writer.add_summary(
                        step_summary,
                        epoch * (num_train // flags.batch_size) +
                        step // flags.batch_size * num_train // num_test)
                    print('Test loss_CE:{}'.format(step_ce))
                saver.save(sess, '{}/model.ckpt'.format(flags.model_dir))

        finally:
            coord.request_stop()
            coord.join(threads)
            saver.save(sess, "{}/model.ckpt".format(flags.model_dir))
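
# The train_op helper used in the last two examples is defined elsewhere. A
# minimal sketch, assuming it wraps a single Adam step over all trainable
# variables (the real helper may use a different optimizer or pass the global
# step explicitly):
def train_op(loss, learning_rate, global_step=None):
    # One optimizer update on all trainable variables.
    optimizer = tf.train.AdamOptimizer(learning_rate)
    return optimizer.minimize(loss, global_step=global_step)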