Example 1
def train(net_factory,
          prefix,
          end_epoch,
          base_dir,
          log_dir,
          display=200,
          base_lr=0.01,
          quantize=True,
          ckpt=None,
          optimizer='momentum'):
    """
    train PNet/RNet/ONet
    :param net_factory:
    :param prefix: model path
    :param end_epoch:
    :param dataset:
    :param display:
    :param base_lr:
    :return:
    """
    print('start training: ....')
    net = prefix.split('/')[-1]
    #label file
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    #label_file = os.path.join(base_dir,'landmark_12_few.txt')
    print(label_file)
    # get number of training examples
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total size of the dataset is: ", num)
    print(prefix)

    # PNet reads its data from a single shuffled tfrecord
    if net == 'PNet':
        #dataset_dir = os.path.join(base_dir,'train_%s_ALL.tfrecord_shuffle' % net)
        dataset_dir = os.path.join(base_dir,
                                   'train_%s_landmark.tfrecord_shuffle' % net)
        print('dataset dir is:', dataset_dir, 'batch_size = ',
              config.BATCH_SIZE)
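        # no_landmarks is assumed to be a module-level value: the number of landmark
        # points (the landmark_target placeholder below uses no_landmarks * 2 coords)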
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net, no_landmarks)

    # RNet/ONet read their data from 4 tfrecords (pos, part, neg, landmark)
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        #landmark_dir = os.path.join(base_dir,'landmark_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir,
                                    'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        #print('batch_size is:', batch_sizes)
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net, no_landmarks)

    # per-net image size and loss weights
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1
        image_size = 48

    #define placeholder
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, no_landmarks * 2],
        name='landmark_target')
    #get loss and accuracy
    input_image = image_color_distort(input_image)
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op, landmark_pred = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # build the train op and learning-rate schedule from the weighted sum of the three losses plus L2
    # count_nan_op = count_nan(landmark_pred)
    total_loss_op = radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op + radio_landmark_loss * landmark_loss_op + L2_loss_op
    train_op, lr_op = train_model(base_lr, total_loss_op, num, quantize,
                                  optimizer)

    # init
    sess = tf.Session()

    #save model
    saver = tf.train.Saver(max_to_keep=10)
    step = 0
    if ckpt is not None:
        saver.restore(sess, ckpt)
        # get last global step
        step = int(os.path.basename(ckpt).split('-')[1])
        print('restored from last step = ', step)
    else:
        init = tf.global_variables_initializer()
        sess.run(init)

    #visualize some variables
    tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  #bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  #landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  #cls_acc
    tf.summary.scalar(
        "total_loss", total_loss_op
    )  #cls_loss, bbox loss, landmark loss and L2 loss add together
    summary_op = tf.summary.merge_all()
    logs_dir = os.path.join(log_dir, net)
    if not os.path.exists(logs_dir):
        os.mkdir(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)
    #begin
    coord = tf.train.Coordinator()
    #begin enqueue thread
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # i = 0

    #total steps
    step_per_epoch = int(num / config.BATCH_SIZE + 1)
    print('step_per_epoch = ', step_per_epoch)
    MAX_STEP = step_per_epoch * end_epoch
    epoch = 0
    sess.graph.finalize()
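    # best total loss seen so far; a checkpoint is saved only when it improves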
    current_total_loss = 100000
    try:
        for i in range(MAX_STEP):
            # i = i + 1
            # j = i
            step = step + 1
            if coord.should_stop():
                break
            # print ('train step = ', step, image_batch.shape, bbox_batch.shape, landmark_batch.shape)
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            #random flip
            # print('after batch array')
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)

            # print('->>>>> 1')
            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })

            if (step + 1) % display == 0:
                #acc = accuracy(cls_pred, labels_batch)
                # print('->>>>> 2')
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc, landmark_pred_val = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op, landmark_pred
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })
                if math.isnan(landmark_loss):
                    print('break, landmark loss is nan', landmark_loss)
                    print('landmark pred val ', landmark_pred_val)
                    # nan_count = sess.run([count_nan_op], feed_dict={landmark_pred: landmark_pred_val})
                    print('no of nan in landmark_pred_val',
                          count_nan(landmark_pred_val))
                    print('no of nan in landmark_target',
                          count_nan(landmark_batch_array))
                    # print('other metrics ', square_error_val, k_index_val, valid_inds_val)
                    break

                total_loss = radio_cls_loss * cls_loss + radio_bbox_loss * bbox_loss + radio_landmark_loss * landmark_loss + L2_loss
                # landmark loss: %4f,
                print(
                    "%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f "
                    % (datetime.now(), step + 1, MAX_STEP, acc, cls_loss,
                       bbox_loss, landmark_loss, L2_loss, total_loss, lr))
                if total_loss < current_total_loss:
                    current_total_loss = total_loss
                    path_prefix = saver.save(sess, prefix, global_step=step)
                    print('Total loss improved, save model ', path_prefix)
            # save every end of epochs
            # if i > 0 and i % step_per_epoch == 0:
            #     path_prefix = saver.save(sess, prefix, global_step=step)
            #     print('Save end of epoch, path prefix is :', path_prefix)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
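In the RNet/ONet branch above, the batch is split across the four tfrecords in a fixed 1:1:3:1 pos/part/neg/landmark ratio; all three examples repeat the same arithmetic. A minimal standalone sketch of that split (split_batch is a hypothetical helper, not part of the code above):

import numpy as np

def split_batch(total_batch_size):
    # same 1/6 : 1/6 : 3/6 : 1/6 shares as the examples, rounded up with np.ceil
    ratios = {'pos': 1.0 / 6, 'part': 1.0 / 6, 'neg': 3.0 / 6, 'landmark': 1.0 / 6}
    sizes = {name: int(np.ceil(total_batch_size * r)) for name, r in ratios.items()}
    assert all(s > 0 for s in sizes.values()), "Batch Size Error"
    return sizes

print(split_batch(384))  # {'pos': 64, 'part': 64, 'neg': 192, 'landmark': 64}

Because each share is rounded up, the four sizes can sum to slightly more than the total batch size; the examples accept this.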
Example 2
def train(net_factory,
          model_save_path,
          max_epoch,
          tfrecord_path,
          display=200,
          base_lr=0.01):
    """ train PNet/RNet/ONet
    :param net_factory: 对应网络的模型
    :param model_save_path: 模型参数保存路径
    :param max_epoch: 迭代次数
    :param tfrecord_path: pos, neg, part, landmark 4类标签数据tfrecord所在路径
    :param display:
    :param base_lr:
    :return:
    """
    net = model_save_path.split('/')[-1]
    # label file
    label_file = tfrecord_path[0]
    # label_file = os.path.join(base_dir,'landmark_12_few.txt')
    print(label_file)
    # get number of training examples
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total size of the dataset is: ", num)
    print(model_save_path)

    # PNet reads its data from a single shuffled tfrecord
    if net == 'PNet':
        print('dataset dir is:', tfrecord_path[1])
        image_batch, label_batch, bbox_batch, landmark_batch = \
            read_single_tfrecord(tfrecord_path[1], config.BATCH_SIZE, net)
    # RNet and ONet read their data from 4 tfrecords (pos, part, neg, landmark)
    else:
        pos_dir = tfrecord_path[1]
        part_dir = tfrecord_path[2]
        neg_dir = tfrecord_path[3]
        landmark_dir = tfrecord_path[4]
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        # print('batch_size is:', batch_sizes)
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)

    path_config = PathConfiguration().config
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 1.0
        radio_landmark_loss = 0.5
        logs_dir = path_config.pnet_log_path
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 1.0
        radio_landmark_loss = 0.5
        logs_dir = path_config.rnet_log_path
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 1.0
        radio_landmark_loss = 1.0
        image_size = 48
        logs_dir = path_config.onet_log_path

    # define placeholder
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')
    # get loss and accuracy
    input_image = image_color_distort(input_image)
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # build the train op and learning-rate schedule from the weighted sum of the three losses plus L2
    total_loss_op = radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op
    total_loss_op = total_loss_op + radio_landmark_loss * landmark_loss_op + L2_loss_op
    train_op, lr_op = train_model(base_lr, total_loss_op, num)
    # init
    os.environ["CUDA_VISIBLE_DEVICES"] = config.VISIBLE_GPU  # e.g. '0,1,2,3'
    # note: tf.device() only places ops when used as a context manager
    # ("with tf.device(...)"), so this bare call has no effect on the graph
    tf.device('/gpu:{}'.format(config.GPU))
    init = tf.global_variables_initializer()
    sess = tf.Session()

    # save model
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)

    # visualize some variables
    tf.summary.scalar("cls_loss", cls_loss_op)  # cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  # bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  # landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  # cls_acc
    tf.summary.scalar(
        "total_loss", total_loss_op
    )  # cls_loss, bbox loss, landmark loss and L2 loss add together
    summary_op = tf.summary.merge_all()

    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)
    # begin
    coord = tf.train.Coordinator()
    # begin enqueue thread
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # total steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * max_epoch
    epoch = 0
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # random flip
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)
            '''
            print('im here')
            print(image_batch_array.shape)
            print(label_batch_array.shape)
            print(bbox_batch_array.shape)
            print(landmark_batch_array.shape)
            print(label_batch_array[0])
            print(bbox_batch_array[0])
            print(landmark_batch_array[0])
            '''

            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })

            if (step + 1) % display == 0:
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })

                total_loss = radio_cls_loss * cls_loss + radio_bbox_loss * bbox_loss
                total_loss = total_loss + radio_landmark_loss * landmark_loss + L2_loss
                # landmark loss: %4f,
                print((
                    "%s - Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, "
                    +
                    "Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f ")
                      % (datetime.now(), step + 1, MAX_STEP, acc, cls_loss,
                         bbox_loss, landmark_loss, L2_loss, total_loss, lr))

            # save every two epochs
            if i * config.BATCH_SIZE > num * 2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess,
                                         model_save_path,
                                         global_step=epoch * 2)
                print('path prefix is :', path_prefix)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("异常结束!!!")
    finally:
        print("完成!!!")
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
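Example 2 differs from Examples 1 and 3 mainly in the loss weights (bbox weight 1.0 instead of 0.5) and in checkpointing every two epochs instead of on improvement. The weighted total loss it computes can be sketched with a small hypothetical helper (weights taken from Example 2):

def total_loss(net, cls_loss, bbox_loss, landmark_loss, l2_loss):
    # per-net (cls, bbox, landmark) weights as hard-coded in Example 2;
    # Examples 1 and 3 use 0.5 for the bbox weight instead
    weights = {
        'PNet': (1.0, 1.0, 0.5),
        'RNet': (1.0, 1.0, 0.5),
        'ONet': (1.0, 1.0, 1.0),
    }
    w_cls, w_bbox, w_lmk = weights[net]
    return w_cls * cls_loss + w_bbox * bbox_loss + w_lmk * landmark_loss + l2_loss

print(total_loss('ONet', 0.5, 0.25, 0.125, 0.0625))  # 0.9375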
Example 3
def train(net_factory, prefix, end_epoch, base_dir, display=200, base_lr=0.01):
    """
    train PNet/RNet/ONet
    :param net_factory:
    :param prefix: model path
    :param end_epoch:
    :param dataset:
    :param display:
    :param base_lr:
    :return:
    """
    net = prefix.split('/')[-1]
    #label file
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    #label_file = os.path.join(base_dir,'landmark_12_few.txt')
    print(label_file)
    # get number of training examples
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total size of the dataset is: ", num)
    print(prefix)

    # PNet reads its data from a single shuffled tfrecord
    if net == 'PNet':
        #dataset_dir = os.path.join(base_dir,'train_%s_ALL.tfrecord_shuffle' % net)
        dataset_dir = os.path.join(base_dir,
                                   'train_%s_landmark.tfrecord_shuffle' % net)
        print('dataset dir is:', dataset_dir)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)

    # RNet/ONet read their data from 4 tfrecords (pos, part, neg, landmark)
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        #landmark_dir = os.path.join(base_dir,'landmark_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join('../../DATA/imglists/RNet',
                                    'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        #print('batch_size is:', batch_sizes)
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)

    # per-net image size and loss weights
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1
        image_size = 48

    #define placeholder
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')
    #get loss and accuracy
    input_image = image_color_distort(input_image)
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # build the train op and learning-rate schedule from the weighted sum of the three losses plus L2
    total_loss_op = radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op + radio_landmark_loss * landmark_loss_op + L2_loss_op
    train_op, lr_op = train_model(base_lr, total_loss_op, num)
    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()

    #save model
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)

    #visualize some variables
    tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  #bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  #landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  #cls_acc
    tf.summary.scalar(
        "total_loss", total_loss_op
    )  #cls_loss, bbox loss, landmark loss and L2 loss add together
    summary_op = tf.summary.merge_all()
    logs_dir = "./logs/%s" % (net)
    if not os.path.exists(logs_dir):
        os.mkdir(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)
    #begin
    coord = tf.train.Coordinator()
    #begin enqueue thread
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    #total steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    sess.graph.finalize()
    try:

        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            #random flip
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)
            '''
            print('im here')
            print(image_batch_array.shape)
            print(label_batch_array.shape)
            print(bbox_batch_array.shape)
            print(landmark_batch_array.shape)
            print(label_batch_array[0])
            print(bbox_batch_array[0])
            print(landmark_batch_array[0])
            '''

            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })

            if (step + 1) % display == 0:
                #acc = accuracy(cls_pred, labels_batch)
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })

                total_loss = radio_cls_loss * cls_loss + radio_bbox_loss * bbox_loss + radio_landmark_loss * landmark_loss + L2_loss
                # landmark loss: %4f,
                print(
                    "%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f "
                    % (datetime.now(), step + 1, MAX_STEP, acc, cls_loss,
                       bbox_loss, landmark_loss, L2_loss, total_loss, lr))

            #save every two epochs
            if i * config.BATCH_SIZE > num * 2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, prefix, global_step=epoch * 2)
                print('path prefix is :', path_prefix)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
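All three examples size the training run the same way: roughly one optimizer step per batch, so the number of steps per epoch is the dataset size divided by config.BATCH_SIZE (plus one, as in the originals), multiplied by the requested number of epochs. A minimal sketch with illustrative numbers (training_steps is a hypothetical helper):

def training_steps(num_examples, batch_size, end_epoch):
    # mirrors MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    step_per_epoch = int(num_examples / batch_size + 1)
    return step_per_epoch * end_epoch

print(training_steps(num_examples=1200000, batch_size=384, end_epoch=30))  # 93780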