Example #1
0
def train(net_factory, prefix, end_epoch, base_dir, display=200, base_lr=0.01):
    """Train one MTCNN stage (PNet/RNet/ONet).

    :param net_factory: builder returning (cls_loss, bbox_loss,
        landmark_loss, L2_loss, accuracy) ops for one batch
    :param prefix: checkpoint path prefix; its basename selects the net
        (e.g. '../data/MTCNN_model/PNet_landmark/PNet' -> 'PNet')
    :param end_epoch: number of epochs to train
    :param base_dir: directory holding the label file and tfrecord(s)
    :param display: print losses every `display` steps
    :param base_lr: initial learning rate
    """
    net = prefix.split('/')[-1]

    # Number of training examples = number of lines in the label file.
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    print(label_file)
    # Context manager closes the handle (it was leaked before).
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total datasets is: ", num)
    print(prefix)

    if net == 'PNet':  # PNet reads all examples from one shuffled tfrecord
        dataset_dir = os.path.join(base_dir,
                                   'train_%s_landmark.tfrecord_shuffle' % net)
        print(dataset_dir)
        # One call yields a full batch of images plus all three target kinds.
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)

    else:  # RNet/ONet mix four tfrecords (pos/part/neg/landmark) per batch
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir,
                                    'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # Sampling ratios: pos/part/landmark 1/6 each, neg 3/6.
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)

    # Per-net input size and loss weights.
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:  # ONet weighs the landmark loss fully
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1.0
        image_size = 48

    # Placeholders for one training batch.
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')

    # Step counter fed from the training loop; drives the lr schedule.
    global_ = tf.Variable(tf.constant(0), trainable=False)

    # Forward pass: classification + bbox regression + landmark regression.
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # Weighted total loss; detection/alignment terms are down-weighted by 0.5.
    train_op, lr_op = train_model(
        base_lr,
        radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op +
        radio_landmark_loss * landmark_loss_op + L2_loss_op, num, global_)
    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()
    # save model
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)
    # Summaries for TensorBoard.
    tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  #bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  #landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  #cls_acc
    tf.summary.scalar("learning rate", lr_op)  # learning rate
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" % (net)
    if not os.path.exists(logs_dir):
        os.mkdir(logs_dir)

    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    # Start the input-queue threads that feed the tfrecord pipelines.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # Total steps over all epochs.
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    # Freeze the graph so the loop cannot keep adding nodes (memory growth).
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            # Queue runners raise OutOfRangeError when exhausted; also honor
            # an external stop request between steps.
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # Random horizontal flip for augmentation.
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)
            # One optimization step on the weighted total loss.
            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array,
                    global_: step
                })
            # Periodically print the individual loss terms.
            if (step + 1) % display == 0:
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array,
                        global_: step
                    })

                print(
                    "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f, lr: %4f"
                    % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                       landmark_loss, L2_loss, lr))
            # Save a checkpoint every two epochs.
            if i * config.BATCH_SIZE > num * 2:
                epoch = epoch + 1
                i = 0
                saver.save(sess, prefix, global_step=epoch * 2)
            if (step + 1) % display == 0:
                writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
Example #2
0
def train(net_factory, prefix, end_epoch, base_dir,
          display=200, base_lr=0.01):
    """
    train PNet/RNet/ONet
    :param net_factory: builder returning the loss and accuracy ops
    :param prefix: checkpoint path prefix; its basename selects the net
    :param end_epoch: number of epochs to train (e.g. 16)
    :param base_dir: directory holding the label file and tfrecords
    :param display: print losses every `display` steps
    :param base_lr: initial learning rate
    :return: None
    """
    net = prefix.split('/')[-1]
    #label file: one training example per line
    label_file = os.path.join(base_dir,'train_%s_landmark.txt' % net)
    print (label_file)
    # context manager closes the handle (it was leaked before)
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total datasets is: ", num)
    print (prefix)

    #PNet reads all examples from a single shuffled tfrecord
    if net == 'PNet':
        dataset_dir = os.path.join(base_dir,'train_%s_landmark.tfrecord_shuffle' % net)
        print (dataset_dir)
        image_batch, label_batch, bbox_batch,landmark_batch = read_single_tfrecord(dataset_dir, config.BATCH_SIZE, net)

    #RNet/ONet mix four tfrecords (pos/part/neg/landmark) in each batch
    else:
        pos_dir = os.path.join(base_dir,'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir,'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir,'neg_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir,'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir,part_dir,neg_dir,landmark_dir]
        # sampling ratios: pos/part/landmark 1/6 each, neg 3/6
        pos_radio = 1.0/6;part_radio = 1.0/6;landmark_radio=1.0/6;neg_radio=3.0/6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE*pos_radio))
        assert pos_batch_size != 0,"Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE*part_radio))
        assert part_batch_size != 0,"Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE*neg_radio))
        assert neg_batch_size != 0,"Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE*landmark_radio))
        assert landmark_batch_size != 0,"Batch Size Error "
        batch_sizes = [pos_batch_size,part_batch_size,neg_batch_size,landmark_batch_size]
        image_batch, label_batch, bbox_batch,landmark_batch = read_multi_tfrecords(dataset_dirs,batch_sizes, net)

    #per-net input size and loss weights
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 0.5;
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 0.5;
    else:
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 1.0;
        image_size = 48

    #define placeholder
    input_image = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, image_size, image_size, 3], name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4], name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,shape=[config.BATCH_SIZE,10],name='landmark_target')
    #class,regression
    cls_loss_op,bbox_loss_op,landmark_loss_op,L2_loss_op,accuracy_op = net_factory(input_image, label, bbox_target,landmark_target,training=True)
    #train op on the weighted total loss; lr_op exposes the decayed lr
    train_op, lr_op = train_model(base_lr, radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + radio_landmark_loss*landmark_loss_op + L2_loss_op, num)
    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()
    #save model
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)
    #visualize some variables
    tf.summary.scalar("cls_loss",cls_loss_op)#cls_loss
    tf.summary.scalar("bbox_loss",bbox_loss_op)#bbox_loss
    tf.summary.scalar("landmark_loss",landmark_loss_op)#landmark_loss
    tf.summary.scalar("cls_accuracy",accuracy_op)#cls_acc
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" %(net)
    if not os.path.exists(logs_dir):
        os.mkdir(logs_dir)
    writer = tf.summary.FileWriter(logs_dir,sess.graph)
    #begin 
    coord = tf.train.Coordinator()
    #begin enqueue thread
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    #total steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    # freeze the graph so the loop cannot keep adding nodes
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array,landmark_batch_array = sess.run([image_batch, label_batch, bbox_batch,landmark_batch])
            #random flip augmentation
            image_batch_array,landmark_batch_array = random_flip_images(image_batch_array,label_batch_array,landmark_batch_array)
            _,_,summary = sess.run([train_op, lr_op ,summary_op], feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array,landmark_target:landmark_batch_array})

            if (step+1) % display == 0:
                #re-run the individual loss ops on the same batch for logging
                cls_loss, bbox_loss,landmark_loss,L2_loss,lr,acc = sess.run([cls_loss_op, bbox_loss_op,landmark_loss_op,L2_loss_op,lr_op,accuracy_op],
                                                             feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, landmark_target: landmark_batch_array})
                print("%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f,lr:%f " % (
                datetime.now(), step+1, acc, cls_loss, bbox_loss, landmark_loss, L2_loss, lr))
            #save every two epochs
            if i * config.BATCH_SIZE > num*2:
                epoch = epoch + 1
                i = 0
                saver.save(sess, prefix, global_step=epoch*2)
            writer.add_summary(summary,global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
Example #3
0
def train(net_factory, prefix, end_epoch, base_dir,
          display=200, base_lr=0.01):
    """
    train PNet/RNet/ONet
    :param net_factory: builder returning the loss and accuracy ops
    :param prefix: checkpoint path prefix; its basename selects the net
    :param end_epoch: number of epochs to train (e.g. 16)
    :param base_dir: directory holding the label file and tfrecords
    :param display: print losses every `display` steps
    :param base_lr: initial learning rate
    :return: None
    """
    net = prefix.split('/')[-1]
    #label file: one training example per line
    label_file = os.path.join(base_dir,'train_%s_landmark.txt' % net)
    # print() calls: the original used Python-2 print statements, which are
    # syntax errors under Python 3
    print(label_file)
    # context manager closes the handle (it was leaked before)
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total datasets is: ", num)
    print(prefix)

    #PNet reads all examples from a single shuffled tfrecord
    if net == 'PNet':
        dataset_dir = os.path.join(base_dir,'train_%s_landmark.tfrecord_shuffle' % net)
        print(dataset_dir)
        image_batch, label_batch, bbox_batch,landmark_batch = read_single_tfrecord(dataset_dir, config.BATCH_SIZE, net)

    #RNet/ONet mix four tfrecords (pos/part/neg/landmark) in each batch
    else:
        pos_dir = os.path.join(base_dir,'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir,'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir,'neg_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir,'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir,part_dir,neg_dir,landmark_dir]
        # sampling ratios: pos/part/landmark 1/6 each, neg 3/6
        pos_radio = 1.0/6;part_radio = 1.0/6;landmark_radio=1.0/6;neg_radio=3.0/6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE*pos_radio))
        assert pos_batch_size != 0,"Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE*part_radio))
        assert part_batch_size != 0,"Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE*neg_radio))
        assert neg_batch_size != 0,"Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE*landmark_radio))
        assert landmark_batch_size != 0,"Batch Size Error "
        batch_sizes = [pos_batch_size,part_batch_size,neg_batch_size,landmark_batch_size]
        image_batch, label_batch, bbox_batch,landmark_batch = read_multi_tfrecords(dataset_dirs,batch_sizes, net)

    #per-net input size and loss weights
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 0.5;
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 0.5;
    else:
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 1.0;
        image_size = 48

    #define placeholder
    input_image = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, image_size, image_size, 3], name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4], name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,shape=[config.BATCH_SIZE,10],name='landmark_target')
    #class,regression
    cls_loss_op,bbox_loss_op,landmark_loss_op,L2_loss_op,accuracy_op = net_factory(input_image, label, bbox_target,landmark_target,training=True)
    #train op on the weighted total loss; lr_op exposes the decayed lr
    train_op, lr_op = train_model(base_lr, radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + radio_landmark_loss*landmark_loss_op + L2_loss_op, num)
    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()
    #save model
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)
    #visualize some variables
    tf.summary.scalar("cls_loss",cls_loss_op)#cls_loss
    tf.summary.scalar("bbox_loss",bbox_loss_op)#bbox_loss
    tf.summary.scalar("landmark_loss",landmark_loss_op)#landmark_loss
    tf.summary.scalar("cls_accuracy",accuracy_op)#cls_acc
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" %(net)
    if not os.path.exists(logs_dir):
        os.mkdir(logs_dir)
    writer = tf.summary.FileWriter(logs_dir,sess.graph)
    #begin 
    coord = tf.train.Coordinator()
    #begin enqueue thread
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    #total steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    # freeze the graph so the loop cannot keep adding nodes
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array,landmark_batch_array = sess.run([image_batch, label_batch, bbox_batch,landmark_batch])
            #random flip augmentation
            image_batch_array,landmark_batch_array = random_flip_images(image_batch_array,label_batch_array,landmark_batch_array)
            _,_,summary = sess.run([train_op, lr_op ,summary_op], feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array,landmark_target:landmark_batch_array})

            if (step+1) % display == 0:
                #re-run the individual loss ops on the same batch for logging
                cls_loss, bbox_loss,landmark_loss,L2_loss,lr,acc = sess.run([cls_loss_op, bbox_loss_op,landmark_loss_op,L2_loss_op,lr_op,accuracy_op],
                                                             feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, landmark_target: landmark_batch_array})
                print("%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f,lr:%f " % (
                datetime.now(), step+1, acc, cls_loss, bbox_loss, landmark_loss, L2_loss, lr))
            #save every two epochs
            if i * config.BATCH_SIZE > num*2:
                epoch = epoch + 1
                i = 0
                saver.save(sess, prefix, global_step=epoch*2)
            writer.add_summary(summary,global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
Example #4
0
def train():
  """Train face_reg for a number of steps across FLAGS.num_gpus GPU towers."""
  with tf.Graph().as_default(), tf.device('/cpu:0'):
    # Create a variable to count the number of train() calls. This equals the
    # number of batches processed * FLAGS.num_gpus.
    global_step = tf.get_variable(
        'global_step', [],
        initializer=tf.constant_initializer(0), trainable=False)

    # Derive the learning-rate schedule from the training-list length.
    item = './data/facescrub_train.list'
    # Context manager closes the handle (it was leaked before).
    with open(item, 'r') as imagelist:
      files_item = imagelist.readlines()
    file_len = len(files_item)
    # Integer division: plain '/' yields a float in Python 3, which would
    # make the `step % (...)` checkpoint test below almost never fire.
    num_batches_per_epoch = file_len // FLAGS.batch_size
    decay_steps = int(num_batches_per_epoch * 10)
    batch_size = FLAGS.batch_size
    img_shape = 300

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(FLAGS.learn_rate,
                                    global_step,
                                    decay_steps,
                                    0.1,
                                    staircase=True)

    # Create an optimizer that performs gradient descent.
    opt = tf.train.GradientDescentOptimizer(lr)

    # Input pipelines for training and validation.
    tfrecord_file = './data/MegaFace_train.tfrecord_shuffle'
    val_file = './data/MegaFace_val.tfrecord_shuffle'
    images, labels = read_single_tfrecord(tfrecord_file, batch_size, img_shape)
    val_image_batch, val_label_batch = read_single_tfrecord(val_file, batch_size, img_shape)
    batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * FLAGS.num_gpus)
    # Calculate the gradients for each model tower.
    tower_grads = []
    with tf.variable_scope(tf.get_variable_scope()):
      for i in range(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
          with tf.name_scope('%s_%d' % ('tower', i)) as scope:
            # Dequeues one batch for the GPU
            image_batch, label_batch = batch_queue.dequeue()
            # Build the full model for this tower; variables are shared
            # across all towers.
            loss,center_op = tower_loss(scope, image_batch, label_batch)

            # Reuse variables for the next tower.
            tf.get_variable_scope().reuse_variables()

            # Retain the summaries from the final tower.
            summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)

            # Compute gradients only after the center-loss update has run.
            with tf.control_dependencies([center_op]):
                grads = opt.compute_gradients(loss)

            # Keep track of the gradients across all towers.
            tower_grads.append(grads)

    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grads = average_gradients(tower_grads)

    # Add a summary to track the learning rate.
    summaries.append(tf.summary.scalar('learning_rate', lr))

    # Add histograms for gradients.
    for grad, var in grads:
      if grad is not None:
        summaries.append(tf.summary.histogram(var.op.name + '/gradients', grad))

    # Apply the gradients to adjust the shared variables.
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # Add histograms for trainable variables.
    for var in tf.trainable_variables():
      summaries.append(tf.summary.histogram(var.op.name, var))

    # Track the moving averages of all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Group all updates to into a single train op.
    train_op = tf.group(apply_gradient_op, variables_averages_op)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    # Build the summary operation from the last tower summaries.
    summary_op = tf.summary.merge(summaries)

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    # Start running operations on the Graph. allow_soft_placement must be set to
    # True to build towers on GPU, as some of the ops do not have GPU
    # implementations.
    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    # Checkpoint roughly once per epoch; guard against a zero interval when
    # the list is shorter than batch_size * num_gpus.
    ckpt_interval = max(1, num_batches_per_epoch // FLAGS.num_gpus)

    for step in range(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = duration / FLAGS.num_gpus

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print (format_str % (datetime.now(), step, loss_value,
                             examples_per_sec, sec_per_batch))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % ckpt_interval == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
def train(net_factory, prefix, end_epoch, base_dir,
          display=100, base_lr=0.01, with_gesture = False):

    """
    train PNet/RNet/ONet with an optional gesture-classification loss
    :param net_factory: a function defined in mtcnn_model.py
    :param prefix: model path; its basename selects the net
    :param end_epoch: number of epochs to train
    :param base_dir: directory holding the label file and tfrecords
    :param display: print losses every `display` steps
    :param base_lr: initial learning rate
    :param with_gesture: include the gesture loss in the total loss if True
    :return: None

    """
    net = prefix.split('/')[-1]
    #label file: one training example per line
    label_file = os.path.join(base_dir,'train_%s_gesture.txt' % net)
    print("--------------------------------------------------------")
    print(label_file)
    # context manager closes the handle (it was leaked before)
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("number of training examples: ", num)

    print("Total size of the dataset is: ", num)
    print(prefix)
    print("--------------------------------------------------------")

    #PNet reads all examples from a single shuffled tfrecord
    if net == 'PNet':
        dataset_dir = os.path.join(base_dir,'train_%s_gesture.tfrecord_shuffle' % net)
        print('dataset dir is:',dataset_dir)
        image_batch, label_batch, bbox_batch, gesture_batch = read_single_tfrecord(dataset_dir, config.BATCH_SIZE, net)

    #RNet/ONet mix four tfrecords (pos/part/neg/gesture) in each batch
    else:
        pos_dir = os.path.join(base_dir,'train_%s_pos_gesture.tfrecord_shuffle' % net)
        part_dir = os.path.join(base_dir,'train_%s_part_gesture.tfrecord_shuffle' % net)
        neg_dir = os.path.join(base_dir,'train_%s_neg_gesture.tfrecord_shuffle' % net)
        gesture_dir = os.path.join(base_dir,'train_%s_gesture.tfrecord_shuffle' % net)
        dataset_dirs = [pos_dir,part_dir,neg_dir,gesture_dir]
        # sampling ratios: pos/part/gesture 1/6 each, neg 3/6
        pos_radio = 1.0/6;part_radio = 1.0/6;gesture_radio=1.0/6;neg_radio=3.0/6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE*pos_radio))
        assert pos_batch_size != 0,"Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE*part_radio))
        assert part_batch_size != 0,"Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE*neg_radio))
        assert neg_batch_size != 0,"Batch Size Error "
        gesture_batch_size = int(np.ceil(config.BATCH_SIZE*gesture_radio))
        assert gesture_batch_size != 0,"Batch Size Error "
        # rounding up can overshoot BATCH_SIZE; shrink the gesture share to fit
        if (pos_batch_size+part_batch_size+neg_batch_size+gesture_batch_size)>config.BATCH_SIZE:
            gesture_batch_size = config.BATCH_SIZE - pos_batch_size - part_batch_size - neg_batch_size 
        assert pos_batch_size + part_batch_size + neg_batch_size + gesture_batch_size == config.BATCH_SIZE, "num exceeded batchsize"
        batch_sizes = [pos_batch_size,part_batch_size,neg_batch_size,gesture_batch_size]
        image_batch, label_batch, bbox_batch, gesture_batch = read_multi_tfrecords(dataset_dirs,batch_sizes, net)

    #per-net input size and loss weights
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_gesture_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_gesture_loss = 0.5
    else:
        image_size = 48
        display = 50  # ONet logs more often
        radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_gesture_loss = 1

    #define placeholders
    input_image = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, image_size, image_size, 3], name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4], name='bbox_target')
    gesture_target = tf.placeholder(tf.float32,shape=[config.BATCH_SIZE,3],name='gesture_target')
    # color augmentation, then build the losses and accuracy ops
    input_image = image_color_distort(input_image)
    cls_loss_op,bbox_loss_op,gesture_loss_op,L2_loss_op,accuracy_op = net_factory(input_image, label, bbox_target,gesture_target,training=True)
    if with_gesture:
        total_loss_op  = radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + radio_gesture_loss*gesture_loss_op + L2_loss_op
    else:
        total_loss_op  = radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + L2_loss_op
    #train op on the weighted total loss; lr_op exposes the decayed lr
    train_op, lr_op = train_model(base_lr,
                                  total_loss_op,
                                  num)
    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()

    #save model
    saver = tf.train.Saver(max_to_keep=1)
    sess.run(init)

    #visualize some variables
    tf.summary.scalar("cls_accuracy",accuracy_op)#cls_acc
    tf.summary.scalar("cls_loss",cls_loss_op)#cls_loss
    tf.summary.scalar("bbox_loss",bbox_loss_op)#bbox_loss
    if with_gesture:
        tf.summary.scalar("gesture_loss",gesture_loss_op)#gesture_loss
    tf.summary.scalar("total_loss",total_loss_op)#weighted sum of all losses incl. L2
    tf.summary.scalar("learn_rate",lr_op)#logging learning rate
    summary_op = tf.summary.merge_all()

    # timestamped sub-directory name for this run's logs
    # (renamed from `time` so it no longer shadows the time module)
    run_stamp = 'train-{}-{date:%Y-%m-%d_%H:%M:%S}'.format(net, date=datetime.now() )
    print("-------------------------------------------------------------\n")
    print("the configuration is as follows:")
    # print base_lr (the original passed base_dir here by mistake)
    print("base_lr: {}  lr_factor: {}  end_epoch: {}".format(base_lr, LR_FACTOR, end_epoch))
    print("the sub dir's name is: ", run_stamp)
    print("-------------------------------------------------------------\n")
    logs_dir = "../logs/%s/" %(net)
    logs_dir = logs_dir + run_stamp + "/"
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    writer = tf.summary.FileWriter(logs_dir,sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer,projector_config)
    #begin 
    coord = tf.train.Coordinator()
    #begin enqueue thread
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    #total steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    print("max_step: ", MAX_STEP)
    epoch = 0
    # freeze the graph so the loop cannot keep adding nodes
    sess.graph.finalize()

    try:

        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array,gesture_batch_array = sess.run([image_batch, label_batch, bbox_batch,gesture_batch])
            #random flip augmentation
            image_batch_array, gesture_batch_array = random_flip_images(image_batch_array,label_batch_array, gesture_batch_array)
            _,_,summary = sess.run([train_op, lr_op ,summary_op], feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array,gesture_target:gesture_batch_array})

            if (step+1) % display == 0:
                #re-run the individual loss ops on the same batch for logging
                if with_gesture:
                    cls_loss, bbox_loss,gesture_loss,L2_loss,lr,acc = sess.run([cls_loss_op, bbox_loss_op,gesture_loss_op,L2_loss_op,lr_op,accuracy_op],
                                                                 feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, gesture_target: gesture_batch_array})
                    total_loss = radio_cls_loss*cls_loss + radio_bbox_loss*bbox_loss + radio_gesture_loss*gesture_loss + L2_loss
                    print("%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,gesture loss :%4f,regularisation loss: %4f, Total Loss: %4f ,lr:%f " % (
                    datetime.now(), step+1,MAX_STEP, acc, cls_loss, bbox_loss,gesture_loss, L2_loss,total_loss, lr))
                else: # without gesture loss
                    cls_loss, bbox_loss,L2_loss,lr,acc = sess.run([cls_loss_op, bbox_loss_op,L2_loss_op,lr_op,accuracy_op],
                                                     feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, gesture_target: gesture_batch_array})
                    total_loss = radio_cls_loss*cls_loss + radio_bbox_loss*bbox_loss + L2_loss
                    print("%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,regularisation loss: %4f, Total Loss: %4f ,lr:%f " % (
                    datetime.now(), step+1,MAX_STEP, acc, cls_loss, bbox_loss, L2_loss,total_loss, lr))

            #save every two epochs
            if i * config.BATCH_SIZE > num*2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, prefix, global_step=epoch*2)
                print('path prefix is :', path_prefix)
            writer.add_summary(summary,global_step=step)

    except tf.errors.OutOfRangeError:
        print("Finished!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
def test(net_factory, prefix, base_dir, display=100, batchsize=1):
    """
    Evaluate a trained net on the gesture test set (batch size 1 by default).

    :param net_factory: network factory; called with ``training=False`` it returns
                        ``(cls_pro, bbox_pred, gesture_pred)``
    :param prefix: model path; its last path component is the net name (e.g. PNet)
    :param base_dir: directory holding ``test_<net>_gesture.txt`` and the tfrecord
    :param display: print running metrics every ``display`` steps
    :param batchsize: evaluation batch size; the placeholders are built for this size
    :return: None. Mean accuracy / losses are printed and TF summaries are written
             to a timestamped subdirectory of ``../logs/<net>/``.
    """
    net = prefix.split('/')[-1]
    # label file
    label_file = os.path.join(base_dir, 'test_%s_gesture.txt' % net)
    #label_file = os.path.join(base_dir,'gesture_12_few.txt')
    print(label_file)
    # get number of testing examples; use a context manager so the file is closed
    with open(label_file, 'r') as f:
        lines = f.readlines()
    num = len(lines)
    # NOTE(review): a readline result keeps its trailing newline, so it can never
    # equal "." — this condition is always true and num is always decremented.
    # Presumably it was meant to skip a header/placeholder line; confirm against
    # the label-file format before relying on `num`.
    if lines[0] != ".":
        num -= 1
    print("Total size of the dataset is: ", num)
    print("The prefix is: ", prefix)

    # PNet uses a single tfrecord; RNet/ONet multi-record reading is not implemented here
    dataset_dir = os.path.join(base_dir, 'test_%s_gesture.tfrecord_shuffle' % net)
    print('dataset dir is:', dataset_dir)
    image_batch, label_batch, bbox_batch, gesture_batch = read_single_tfrecord(dataset_dir, batchsize, net)
    image_size = 12

    # set placeholders sized for the evaluation batch
    input_image = tf.placeholder(tf.float32, shape=[batchsize, image_size, image_size, 3], name='input_image')
    label = tf.placeholder(tf.float32, shape=[batchsize], name='label')
    bbox_target = tf.placeholder(tf.float32, shape=[batchsize, 4], name='bbox_target')
    gesture_target = tf.placeholder(tf.float32, shape=[batchsize, 3], name='gesture_target')

    input_image = image_color_distort(input_image)
    cls_pro, bbox_pred, gesture_pred = net_factory(input_image, training=False)

    # build the per-head metric ops from the raw predictions
    accuracy_op = cal_acc(cls_pro, label)
    cls_loss_op = cls_cal_loss(cls_pro, label)
    bbox_loss_op = bbox_cal_loss(bbox_pred, bbox_target, label)
    gesture_loss_op = gesture_cal_loss(gesture_pred, gesture_target, label)

    # init (no Saver needed: evaluation only)
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    # visualize some variables
    tf.summary.scalar("accuracy", accuracy_op)
    tf.summary.scalar("cls_loss", cls_loss_op)
    tf.summary.scalar("bbox_loss", bbox_loss_op)
    tf.summary.scalar("gesture_loss", gesture_loss_op)
    summary_op = tf.summary.merge_all()

    # one timestamped log sub-directory per evaluation run
    time = 'test-{date:%Y-%m-%d_%H:%M:%S}'.format(date=datetime.now())
    print("---------------------------------------------------------------------------\n")
    print("current time: ", time)
    print("---------------------------------------------------------------------------\n")
    print("the sub dir's name is: ", time)
    print("---------------------------------------------------------------------------\n")
    logs_dir = "../logs/%s/" % (net)
    logs_dir = logs_dir + time + "/"
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)

    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)
    # begin enqueue threads for the tfrecord reader
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # one pass over the dataset: with batchsize 1 this is num + 1 steps
    MAX_STEP = int(num / batchsize + 1) * 1
    epoch = 0
    sess.graph.finalize()

    # per-category accumulators for the final evaluation summary
    acc_list = []
    cls_loss_list = []
    bbox_loss_list = []
    gesture_loss_list = []

    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, gesture_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, gesture_batch])
            # random flip
            image_batch_array, gesture_batch_array = random_flip_images(
                image_batch_array, label_batch_array, gesture_batch_array)

            summary = sess.run([summary_op],
                               feed_dict={input_image: image_batch_array, label: label_batch_array,
                                          bbox_target: bbox_batch_array, gesture_target: gesture_batch_array})
            summary = summary[0]

            # sample-type labels used to decide which losses apply to this batch
            pos_label = 1
            neg_label = 0
            part_label = -1
            gesture_label = -2
            # A positive sample contributes to BOTH the classification and the bbox
            # statistics, so these checks are independent `if`s rather than an
            # elif chain (an elif would make the bbox branch unreachable whenever
            # the batch contains a positive sample).
            if (pos_label in label_batch_array) or (neg_label in label_batch_array):
                cls_loss, accuracy = sess.run([cls_loss_op, accuracy_op],
                                              feed_dict={input_image: image_batch_array, label: label_batch_array,
                                                         bbox_target: bbox_batch_array, gesture_target: gesture_batch_array})
                if cls_loss is not None:
                    cls_loss_list.append(cls_loss)
                if accuracy is not None:
                    acc_list.append(accuracy)
            if (pos_label in label_batch_array) or (part_label in label_batch_array):
                bbox_loss = sess.run([bbox_loss_op],
                                     feed_dict={input_image: image_batch_array, label: label_batch_array,
                                                bbox_target: bbox_batch_array, gesture_target: gesture_batch_array})
                if bbox_loss is not None:
                    bbox_loss_list.append(bbox_loss)
            if gesture_label in label_batch_array:
                gesture_loss = sess.run([gesture_loss_op],
                                        feed_dict={input_image: image_batch_array, label: label_batch_array,
                                                   bbox_target: bbox_batch_array, gesture_target: gesture_batch_array})
                if gesture_loss is not None:
                    gesture_loss_list.append(gesture_loss)

            if (step + 1) % display == 0:
                cls_loss, bbox_loss, gesture_loss, accuracy = sess.run(
                    [cls_loss_op, bbox_loss_op, gesture_loss_op, accuracy_op],
                    feed_dict={input_image: image_batch_array, label: label_batch_array,
                               bbox_target: bbox_batch_array, gesture_target: gesture_batch_array})
                print("%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, gesture_loss: %4f  " % (
                    datetime.now(), step + 1, MAX_STEP, accuracy, cls_loss, bbox_loss, gesture_loss))

            # epoch bookkeeping only — no checkpointing during evaluation
            if i * batchsize > num:
                epoch = epoch + 1
                i = 0

            writer.add_summary(summary, global_step=step)

        # final aggregation over the whole pass
        acc_list = np.array(acc_list)
        cls_loss_list = np.array(cls_loss_list)
        bbox_loss_list = np.array(bbox_loss_list)
        gesture_loss_list = np.array(gesture_loss_list)
        mean_acc = np.mean(acc_list)
        mean_cls_loss = np.mean(cls_loss_list)
        mean_bbox_loss = np.mean(bbox_loss_list)
        mean_gesture_loss = np.mean(gesture_loss_list)
        print("-------------------------------summary-------------------------------")
        print("mean cls accuracy: ", mean_acc)
        print("mean cls loss: ", mean_cls_loss)
        print("mean bbox loss: ", mean_bbox_loss)
        print("mean gesture loss: ", mean_gesture_loss)
        print("---------------------------------------------------------------------")

    except tf.errors.OutOfRangeError:
        print("Finished!")
    finally:
        coord.request_stop()
        writer.close()

    coord.join(threads)
    sess.close()
Beispiel #7
0
def train(net_factory,
          prefix,
          load_epoch,
          end_epoch,
          base_dir,
          display=200,
          base_lr=0.01,
          gpu_ctx='/device:GPU:0'):
    """
    Train PNet/RNet/ONet, optionally resuming from a saved checkpoint.

    :param net_factory: network factory returning the loss/accuracy ops
    :param prefix: checkpoint path prefix; its last component is the net name
    :param load_epoch: if truthy, restore ``<prefix>-<load_epoch>`` before training;
                       otherwise initialise all variables from scratch
    :param end_epoch: number of epochs used to size MAX_STEP
    :param base_dir: directory holding the label txt file and tfrecords
    :param display: print losses every ``display`` steps
    :param base_lr: base learning rate passed to ``train_model``
    :param gpu_ctx: device string (currently unused; device placement is default)
    :return: None
    """
    net = prefix.split('/')[-1]
    # label file — only used to count training examples
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    #label_file = os.path.join(base_dir,'landmark_12_few.txt')
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total datasets is: ", num)
    print("saved prefix: ", prefix)

    # PNet reads a single shuffled tfrecord
    if net == 'PNet':
        #dataset_dir = os.path.join(base_dir,'train_%s_ALL.tfrecord_shuffle' % net)
        dataset_dir = os.path.join(base_dir,
                                   'train_%s_landmark.tfrecord_shuffle' % net)
        print("dataset saved dir: ", dataset_dir)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)

    # RNet/ONet mix pos/part/neg (and optionally landmark) tfrecords
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        if train_face:
            landmark_dir = os.path.join(base_dir,
                                        'landmark_landmark.tfrecord_shuffle')
        else:
            landmark_dir = None
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # sampling ratios per batch; with train_face the classic MTCNN 1:1:3:1 mix
        if train_face:
            pos_radio = 1.0 / 6
            part_radio = 1.0 / 6
            landmark_radio = 1.0 / 6
            neg_radio = 3.0 / 6
        else:
            pos_radio = 2.0 / 3
            part_radio = 1.0 / 6
            landmark_radio = 0
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        if train_face:
            neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
            assert neg_batch_size != 0, "Batch Size Error "
            # landmark batch takes whatever remains so the sizes sum to BATCH_SIZE
            landmark_batch_size = int(config.BATCH_SIZE - pos_batch_size -
                                      part_batch_size - neg_batch_size)
            assert landmark_batch_size != 0, "Batch Size Error "
            batch_sizes = [
                pos_batch_size, part_batch_size, neg_batch_size,
                landmark_batch_size
            ]
            image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
                dataset_dirs, batch_sizes, net)
        else:
            landmark_batch_size = 1
            neg_batch_size = int(config.BATCH_SIZE - pos_batch_size -
                                 part_batch_size)
            assert neg_batch_size != 0, "Batch Size Error "
            batch_sizes = [
                pos_batch_size, part_batch_size, neg_batch_size,
                landmark_batch_size
            ]
            image_batch, label_batch, bbox_batch = read_multi_tfrecords(
                dataset_dirs, batch_sizes, net)

    # per-net input size and loss weights
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1.0
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1.0
        image_size = 48

    # define placeholders
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')
    # build loss ops (class + regression [+ landmark])
    #with tf.device(gpu_ctx):
    if train_face:
        cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
            input_image, label, bbox_target, landmark_target, training=True)
    else:
        cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
            input_image, label, bbox_target, training=True)
    # train op and learning-rate schedule over the weighted total loss
    train_op, lr_op = train_model(
        base_lr,
        radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op +
        radio_landmark_loss * landmark_loss_op + L2_loss_op, num)
    # session config: grow GPU memory on demand
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf_config.log_device_placement = False
    sess = tf.Session(config=tf_config)
    # saver keeps every checkpoint (max_to_keep=0)
    saver = tf.train.Saver(max_to_keep=0)
    # load pretrained parameters or initialise from scratch
    if load_epoch:
        model_path = "%s-%s" % (prefix, str(load_epoch))
        model_dict = '/'.join(model_path.split('/')[:-1])
        ckpt = tf.train.get_checkpoint_state(model_dict)
        print("restore model path:", model_path)
        # NOTE(review): readstate is computed but the validity assert is disabled;
        # an invalid checkpoint directory will only surface as a restore failure.
        readstate = ckpt and ckpt.model_checkpoint_path
        #assert readstate, "the params dictionary is not valid"
        saver.restore(sess, model_path)
        print("restore models' param")
    else:
        init = tf.global_variables_initializer()
        sess.run(init)
        print("init models using gloable: init")

    # visualize some variables
    tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  #bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  #landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  #cls_acc
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" % (net)
    if not os.path.exists(logs_dir):
        os.mkdir(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    # begin enqueue threads for the tfrecord readers
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # total steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    sess.graph.finalize()
    save_acc = 0
    L2_loss = 0
    # metrics are only recomputed inside the `display` branch; pre-initialise
    # them so the checkpoint-logging branch below cannot raise a NameError when
    # a save is triggered before the first display step
    acc = cls_loss = bbox_loss = landmark_loss = lr = 0
    model_dict = '/'.join(prefix.split('/')[:-1])
    log_r_file = os.path.join(model_dict, "train_record.txt")
    print("model record is ", log_r_file)
    record_file_out = open(log_r_file, 'w')
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            if train_face:
                if net == 'PNet':
                    image_batch_array, label_batch_array, bbox_batch_array = sess.run(
                        [image_batch, label_batch, bbox_batch])
                else:
                    image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                        [image_batch, label_batch, bbox_batch, landmark_batch])
            else:
                image_batch_array, label_batch_array, bbox_batch_array = sess.run(
                    [image_batch, label_batch, bbox_batch])
            # random flip (landmarks must be flipped together with the images)
            if train_face and not (net == 'PNet'):
                image_batch_array, landmark_batch_array = random_flip_images(
                    image_batch_array, label_batch_array, landmark_batch_array)
            # NOTE(review): when train_face and net == 'PNet' the feed_dict below
            # omits landmark_target even though the train_face graph includes the
            # landmark loss — this relies on PNet's graph not consuming the
            # placeholder; confirm against the net_factory implementation.
            if train_face and not (net == 'PNet'):
                _, _, summary = sess.run(
                    [train_op, lr_op, summary_op],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })
            else:
                _, _, summary = sess.run(
                    [train_op, lr_op, summary_op],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array
                    })

            if (step + 1) % display == 0:
                if train_face and not (net == 'PNet'):
                    cls_loss, bbox_loss,landmark_loss,L2_loss,lr,acc = sess.run([cls_loss_op, bbox_loss_op,landmark_loss_op,L2_loss_op,lr_op,accuracy_op],\
                                                             feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, landmark_target: landmark_batch_array})
                    print(
                        "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f,lr:%f "
                        % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                           landmark_loss, L2_loss, lr))
                else:
                    cls_loss, bbox_loss,L2_loss,lr,acc = sess.run([cls_loss_op, bbox_loss_op,L2_loss_op,lr_op,accuracy_op],\
                                                             feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array})
                    print(
                        "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, L2 loss: %4f,lr:%f "
                        % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                           L2_loss, lr))
            # checkpoint every ~10 epochs worth of samples
            if i * config.BATCH_SIZE > num * 10:
                epoch = epoch + 1
                i = 0
                saver.save(sess, prefix, global_step=epoch * 100)
                save_acc = L2_loss
                if train_face and not (net == 'PNet'):
                    print(
                        "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f,lr:%f "
                        % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                           landmark_loss, L2_loss, lr))
                    record_file_out.write(
                        "%s : epoch: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark_loss: %4f,lr:%f \n"
                        % (datetime.now(), epoch * 100, acc, cls_loss,
                           bbox_loss, landmark_loss, lr))
                else:
                    print(
                        "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, L2 loss: %4f,lr:%f "
                        % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                           L2_loss, lr))
                    record_file_out.write(
                        "%s : epoch: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, L2_loss: %4f,lr:%f \n"
                        % (datetime.now(), epoch * 100, acc, cls_loss,
                           bbox_loss, L2_loss, lr))
                print("model saved over ", save_acc)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("Over!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    record_file_out.close()
    sess.close()
Beispiel #8
0
def train(net_factory, prefix, end_epoch, base_dir, display=200, base_lr=0.01,
          before=16, before_step=67800):
    """
    Resume training of PNet/RNet/ONet from an existing checkpoint.

    :param net_factory: network factory returning the loss/accuracy ops
    :param prefix: model path prefix; ``<prefix>-<before>`` must be a saved checkpoint
    :param end_epoch: final epoch number (training runs end_epoch - before epochs)
    :param base_dir: directory holding the label txt file and tfrecords
    :param display: print losses every ``display`` steps
    :param base_lr: base learning rate passed to ``train_model``
    :param before: epoch of the checkpoint to resume from (was hard-coded to 16)
    :param before_step: global step at the resumed checkpoint, used only to offset
                        the step counter in logs/summaries (was hard-coded to 67800)
    :return: None
    """
    net = prefix.split('/')[-1]
    # label file — only used to count training examples
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    #label_file = os.path.join(base_dir,'landmark_12_few.txt')
    print(label_file)
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total size of the dataset is: ", num)
    print(prefix)

    # PNet reads a single shuffled tfrecord
    if net == 'PNet':
        #dataset_dir = os.path.join(base_dir,'train_%s_ALL.tfrecord_shuffle' % net)
        dataset_dir = os.path.join(base_dir,
                                   'train_%s_landmark.tfrecord_shuffle' % net)
        print('dataset dir is:', dataset_dir)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)

    # RNet/ONet mix pos/part/neg/landmark tfrecords in a 1:1:3:1 ratio
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir,
                                    'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)

    # per-net input size and loss weights
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1
        image_size = 48

    # define placeholders
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')
    # get loss and accuracy
    input_image = image_color_distort(input_image)
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # train op and learning-rate schedule over the weighted total loss
    total_loss_op = radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op + radio_landmark_loss * landmark_loss_op + L2_loss_op
    train_op, lr_op = train_model(base_lr, total_loss_op, num)
    sess = tf.Session()

    # saver keeps every checkpoint; restore the resume checkpoint instead of
    # running a fresh global initialisation
    saver = tf.train.Saver(max_to_keep=0)
    saver.restore(sess, prefix + '-%d' % before)
    # visualize some variables
    tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  #bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  #landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  #cls_acc
    tf.summary.scalar(
        "total_loss", total_loss_op
    )  #cls_loss, bbox loss, landmark loss and L2 loss add together
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" % (net)
    if not os.path.exists(logs_dir):
        os.mkdir(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)
    # begin enqueue threads for the tfrecord readers
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # remaining steps: (steps per epoch) * epochs still to run
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * (end_epoch - before)
    # epoch counter continues from the resume checkpoint
    epoch = before
    sess.graph.finalize()
    try:

        for step in range(before_step, before_step + MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # random flip (landmarks must be flipped together with the images)
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)

            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })

            if (step + 1) % display == 0:
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })

                total_loss = radio_cls_loss * cls_loss + radio_bbox_loss * bbox_loss + radio_landmark_loss * landmark_loss + L2_loss
                eval_result = "%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f " % (
                    datetime.now(), step + 1, MAX_STEP + before_step, acc,
                    cls_loss, bbox_loss, landmark_loss, L2_loss, total_loss,
                    lr)
                print(eval_result + '\n')
                with open(EVALRESULTFILE, 'a') as f:
                    f.write(eval_result + '\n')

            # checkpoint once per epoch worth of samples
            if i * config.BATCH_SIZE > num:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, prefix, global_step=epoch)
                print('path prefix is :', path_prefix)
                with open(EVALRESULTFILE, 'a') as f:
                    f.write('path prefix is :' + path_prefix + '\n')
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
Beispiel #9
0
def test(net_factory, prefix, end_epoch, base_dir, display=100):
    """
    Run an evaluation pass over the test tfrecord with batch size 1.

    The graph is still wired through train_model, but with base_lr = 0 so
    the "training" step performs no effective parameter update.

    :param net_factory: network builder for P/R/ONet; must return
        (cls_loss, bbox_loss, gesture_loss, L2_loss, accuracy) ops
    :param prefix: model path; its basename selects the net ('PNet', ...)
    :param end_epoch: number of passes used to size the total step count
    :param base_dir: directory holding the label file and tfrecord
    :param display: print losses every `display` steps
    :return: None
    """
    net = prefix.split('/')[-1]
    # label file
    label_file = os.path.join(base_dir, 'test_%s_gesture.txt' % net)
    print(label_file)
    # count the testing examples; use a context manager so the file
    # handle is closed (the original leaked it)
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total size of the dataset is: ", num)
    print(prefix)

    # only the single-tfrecord (PNet-style) path is implemented here;
    # the multi-tfrecord reader for RNet/ONet is still TODO
    dataset_dir = os.path.join(base_dir,
                               'test_%s_gesture.tfrecord_shuffle' % net)
    print('dataset dir is:', dataset_dir)
    image_batch, label_batch, bbox_batch, gesture_batch = read_single_tfrecord(
        dataset_dir, 1, net)
    image_size = 12
    # loss weights used to assemble total_loss_op
    radio_cls_loss = 1.0
    radio_bbox_loss = 0.5
    radio_gesture_loss = 0.5

    # placeholders with batch size 1 for testing
    input_image = tf.placeholder(tf.float32,
                                 shape=[1, image_size, image_size, 3],
                                 name='input_image')
    label = tf.placeholder(tf.float32, shape=[1], name='label')
    bbox_target = tf.placeholder(tf.float32, shape=[1, 4], name='bbox_target')
    gesture_target = tf.placeholder(tf.float32,
                                    shape=[1, 3],
                                    name='gesture_target')

    input_image = image_color_distort(input_image)

    cls_loss_op, bbox_loss_op, gesture_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, gesture_target, training=False)
    # combined loss (3 weighted losses + weight decay)
    total_loss_op = radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op + radio_gesture_loss * gesture_loss_op + L2_loss_op
    # for testing, set the base learning rate to 0 so train_op is a no-op
    base_lr = 0
    train_op, lr_op = train_model(base_lr, total_loss_op, num)

    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()

    # saver kept so checkpoints could be written from this session too
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)

    # visualize the three loss terms
    tf.summary.scalar("cls_loss", cls_loss_op)
    tf.summary.scalar("bbox_loss", bbox_loss_op)
    tf.summary.scalar("gesture_loss", gesture_loss_op)

    summary_op = tf.summary.merge_all()
    logs_dir = "../logs_testing/%s" % (net)
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)
    # coordinator + queue-runner threads feed the tfrecord reader
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # NOTE(review): MAX_STEP is computed from config.BATCH_SIZE although
    # the feed batch size here is 1, so each "epoch" only walks a
    # fraction of the dataset — confirm whether that is intended
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    sess.graph.finalize()

    try:

        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, gesture_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, gesture_batch])
            # random flip (kept from the training pipeline)
            image_batch_array, gesture_batch_array = random_flip_images(
                image_batch_array, label_batch_array, gesture_batch_array)

            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    gesture_target: gesture_batch_array
                })

            if (step + 1) % display == 0:

                # re-evaluate the losses on the same batch for logging
                cls_loss, bbox_loss, gesture_loss, accuracy = sess.run(
                    [cls_loss_op, bbox_loss_op, gesture_loss_op, accuracy_op],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        gesture_target: gesture_batch_array
                    })

                print(
                    "%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, gesture_loss: %4f  "
                    % (datetime.now(), step + 1, MAX_STEP, accuracy, cls_loss,
                       bbox_loss, gesture_loss))

            # checkpoint every two "epochs" (i counts batches seen)
            if i * config.BATCH_SIZE > num * 2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, prefix, global_step=epoch * 2)
                print('path prefix is :', path_prefix)

            writer.add_summary(summary, global_step=step)

    except tf.errors.OutOfRangeError:
        print("Finished!( ゜- ゜)つロ乾杯")
    finally:
        coord.request_stop()
        writer.close()

    coord.join(threads)
    sess.close()
Beispiel #10
0
def _split_batch_sizes(pos_radio, part_radio, neg_radio, landmark_radio):
    """Convert per-class sampling ratios into [pos, part, neg, landmark]
    batch sizes: ceil(ratio * config.BATCH_SIZE), each required non-zero."""
    batch_sizes = [
        int(np.ceil(config.BATCH_SIZE * radio))
        for radio in (pos_radio, part_radio, neg_radio, landmark_radio)
    ]
    for batch_size in batch_sizes:
        assert batch_size != 0, "Batch Size Error "
    return batch_sizes


def train(net_factory,
          prefix,
          end_epoch,
          base_dir,
          display=200,
          base_lr=0.01,
          COLOR_GRAY=0):  ###
    """
    Train PNet/RNet/ONet — classification-only variant: only the cls loss
    plus L2 weight decay are optimized; bbox/landmark losses are disabled.

    :param net_factory: network builder; this variant must return
        (cls_loss, L2_loss, accuracy) ops
    :param prefix: checkpoint path; its basename selects the net
    :param end_epoch: number of epochs used to size the total step count
    :param base_dir: directory holding the tfrecord files
    :param display: print losses every `display` steps
    :param base_lr: initial learning rate
    :param COLOR_GRAY: 0 = 3-channel color input, 1 = 1-channel gray
    :return: None
    """
    net = prefix.split('/')[-1]
    # dataset size is hard-coded (the label-file counting code is disabled)
    num = 1600000  #531806  #950000   # 1500000
    print("Total datasets is: ", num)
    print(prefix)

    if net == 'PNet' and not config.SINGLEF:
        # PNet, merged record: one shuffled tfrecord holds all samples
        if COLOR_GRAY == 0:
            dataset_dir = os.path.join(
                base_dir, 'train_%s_12_color.tfrecord_shuffle' % net)
        elif COLOR_GRAY == 1:
            dataset_dir = os.path.join(
                base_dir, 'train_%s_12_gray.tfrecord_shuffle' % net)
        print(dataset_dir)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net, COLOR_GRAY)
    elif net == 'PNet' and config.SINGLEF:
        # PNet, separate records: pos/part/neg mixed inside each batch
        if COLOR_GRAY == 0:
            pos_dir = os.path.join(base_dir,
                                   'PNet_12_color_pos.tfrecord_shuffle')
            part_dir = os.path.join(base_dir,
                                    'PNet_12_color_part.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'PNet_12_color_neg.tfrecord_shuffle')
        elif COLOR_GRAY == 1:
            pos_dir = os.path.join(base_dir,
                                   'PNet_12_gray_pos.tfrecord_shuffle')
            part_dir = os.path.join(base_dir,
                                    'PNet_12_gray_part.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'PNet_12_gray_neg.tfrecord_shuffle')
        landmark_dir = None  # no landmark record in this variant
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # ratios: pos 1/5, part 1/5, neg 3/5 (landmark 1/6 kept for shape)
        batch_sizes = _split_batch_sizes(1.0 / 5, 1.0 / 5, 3.0 / 5, 1.0 / 6)
        landmarkflag = 0
        partflag = 1
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net, landmarkflag, partflag, COLOR_GRAY)
    elif net == 'RNet':
        if COLOR_GRAY == 0:
            pos_dir = os.path.join(base_dir,
                                   'RNet_24_color_pos.tfrecord_shuffle')
            part_dir = os.path.join(base_dir,
                                    'RNet_24_color_part.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'RNet_24_color_neg.tfrecord_shuffle')
        elif COLOR_GRAY == 1:
            pos_dir = os.path.join(base_dir,
                                   'RNet_24_gray_pos.tfrecord_shuffle')
            part_dir = os.path.join(base_dir,
                                    'RNet_24_gray_part.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'RNet_24_gray_neg.tfrecord_shuffle')
        landmark_dir = None  # no landmark record
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        batch_sizes = _split_batch_sizes(1.0 / 5, 1.0 / 5, 3.0 / 5, 1.0 / 6)
        landmarkflag = 0  #  select landmarkflag
        partflag = 1
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net, landmarkflag, partflag, COLOR_GRAY)
    elif net == 'ONet':
        if COLOR_GRAY == 0:
            pos_dir = os.path.join(base_dir,
                                   'ONet_48_color_pos.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'ONet_48_color_neg.tfrecord_shuffle')
        elif COLOR_GRAY == 1:
            pos_dir = os.path.join(base_dir,
                                   'ONet_48_gray_pos.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'ONet_48_gray_neg_single.tfrecord_shuffle')
        # ONet trains from pos/neg only; part and landmark records unused
        part_dir = None
        landmark_dir = None
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        batch_sizes = _split_batch_sizes(1.0 / 2, 1.0 / 6, 1.0 / 2, 1.0 / 6)
        landmarkflag = 0
        partflag = 0
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net, landmarkflag, partflag, COLOR_GRAY)
    else:
        # fail fast instead of hitting an UnboundLocalError further down
        raise ValueError('unknown net: %s' % net)

    # per-net input size and loss weights (bbox/landmark weights are kept
    # for reference but unused in this classification-only variant)
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0
        radio_landmark_loss = 0
        image_size = 48

    # placeholders; channel count follows COLOR_GRAY
    if COLOR_GRAY == 1:
        input_image = tf.placeholder(
            tf.float32,
            shape=[config.BATCH_SIZE, image_size, image_size, 1],
            name='input_image')
    else:
        input_image = tf.placeholder(
            tf.float32,
            shape=[config.BATCH_SIZE, image_size, image_size, 3],
            name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    # NOTE(review): landmark_target is declared with 4 values, not 10 —
    # presumably a reduced landmark set; confirm against net_factory
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 4],
                                     name='landmark_target')
    print('class,regression+')
    # this net_factory variant returns only cls loss, L2 loss and accuracy
    cls_loss_op, L2_loss_op, accuracy_op = net_factory(input_image,
                                                       label,
                                                       bbox_target,
                                                       landmark_target,
                                                       training=True)
    # optimize classification loss + weight decay only
    train_op, lr_op, global_step = train_model(
        base_lr, radio_cls_loss * cls_loss_op + L2_loss_op, num)
    # init
    init = tf.global_variables_initializer()

    # cap GPU memory so several jobs can share one card
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.3)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            log_device_placement=False))
    # save only trainable variables; keep the 3 newest checkpoints
    saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)
    # pick the pretrained checkpoint matching net / color mode
    if net == 'PNet':
        if COLOR_GRAY == 1:
            pretrained_model = './data/new_model/PNet_merge_gray'
        else:
            pretrained_model = './data/new_model/PNet_merge_color'
    elif net == 'RNet':
        if COLOR_GRAY == 1:
            pretrained_model = './data/new_model/RNet_cmerge_gray_gray'
        else:
            pretrained_model = './data/new_model/RNet_cmerge_gray_color'
    elif net == 'ONet':
        pretrained_model = './data/new_model/test1-1ONet_NIR_calib_A_gray'
    else:
        pretrained_model = None

    sess.run(init)
    print(sess.run(global_step))
    if pretrained_model and config.PRETRAIN:
        print('Restoring pretrained model: %s' % pretrained_model)
        ckpt = tf.train.get_checkpoint_state(pretrained_model)
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('Not Pretrain \n')

    print(sess.run(global_step))
    # summaries: cls loss and accuracy only (bbox/landmark are disabled)
    tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  #cls_acc
    summary_op = tf.summary.merge_all()
    logs_dir = "./logs/%s" % (net)
    if not os.path.exists(logs_dir):
        # makedirs (not mkdir) so a missing ./logs parent is created too
        os.makedirs(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    # coordinator + queue-runner threads feed the tfrecord readers
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # total steps over all epochs
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch

    epoch = 0
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # (random flip augmentation intentionally disabled here)
            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })

            if (step + 1) % display == 0:

                # re-evaluate the losses on the same batch for logging
                cls_loss, L2_loss, lr, acc = sess.run(
                    [cls_loss_op, L2_loss_op, lr_op, accuracy_op],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })
                print(
                    "%s : Step: %d, accuracy: %3f, cls loss: %4f, L2 loss: %4f,lr:%f "
                    % (datetime.now(), step + 1, acc, cls_loss, L2_loss, lr))

            # checkpoint once per epoch (i counts batches within the epoch)
            if i * config.BATCH_SIZE > num:
                epoch = epoch + 1
                print('save epoch%d' % epoch)
                i = 0
                saver.save(sess, prefix, global_step=epoch)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
Beispiel #11
0
def train(net_factory, prefix, end_epoch, base_dir, display=200, base_lr=0.01):
    """
    Train PNet/RNet/ONet with classification, bbox regression and
    landmark regression losses jointly.

    :param net_factory: network builder returning
        (cls_loss, bbox_loss, landmark_loss, L2_loss, accuracy) ops
    :param prefix: checkpoint path; its basename selects the net
    :param end_epoch: number of epochs used to size the total step count
    :param base_dir: directory holding the label file and tfrecords
    :param display: print losses every `display` steps
    :param base_lr: initial learning rate
    :return: None
    """
    net = prefix.split('/')[-1]
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    # count training samples; context manager closes the handle
    # (the original leaked it)
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    if net == 'PNet':
        # PNet reads one merged shuffled tfrecord
        dataset_dir = os.path.join(base_dir,
                                   'train_%s_landmark.tfrecord_shuffle' % net)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)
    # other nets read 4 records (pos/part/neg/landmark) mixed per batch
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir,
                                    'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # sampling ratios: 1/6 pos, 1/6 part, 1/6 landmark, 3/6 neg
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)
    # per-net input size and relative loss weights
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1.0
        image_size = 48
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    # 10 values = 5 (x, y) landmark points
    # fixed misspelled graph node name 'labdmark_target'
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')
    # build classification and regression losses
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = \
        net_factory(input_image, label, bbox_target,
        landmark_target, training=True)
    # optimizer driven by the weighted sum of all losses + weight decay
    train_op, lr_op = train_model(
        base_lr,
        radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op +
        radio_landmark_loss * landmark_loss_op + L2_loss_op, num)
    # initialization
    init = tf.global_variables_initializer()
    sess = tf.Session()
    # keep every checkpoint
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)
    # coordinator + queue-runner threads feed the tfrecord readers
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # total training steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # random flip augmentation (images and landmarks together)
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)
            _, _ = sess.run(
                [train_op, lr_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })

            if (step + 1) % display == 0:
                # re-evaluate the losses on the same batch for logging
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })
                print(
                    "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f,lr:%f "
                    % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                       landmark_loss, L2_loss, lr))
            # checkpoint every two epochs (i counts batches seen)
            if i * config.BATCH_SIZE > num * 2:
                epoch = epoch + 1
                i = 0
                saver.save(sess, prefix, global_step=epoch * 2)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
    coord.join(threads)
    sess.close()
Beispiel #12
0
def main():
    """
    Multi-GPU training driver for the MegaFace center-loss model.

    Reads image/label batches from shuffled tfrecords, replicates the
    network over `num_gpus` GPUs via make_parallel, and trains with plain
    SGD. Checkpoints are written every 5 epochs, plus once early in an
    epoch when the loss first drops below 0.1.
    """
    args = argument()
    checkpoint_dir = args.save_model_name
    lr = args.lr
    batch_size = args.batch_size
    epoch_num = args.epoch_num
    img_shape = args.img_shape
    num_gpus = 4
    tfrecord_file = './data/MegaFace_train.tfrecord_shuffle'
    val_file = './data/MegaFace_val.tfrecord_shuffle'
    image_batch, label_batch = read_single_tfrecord(tfrecord_file, batch_size,
                                                    img_shape)
    val_image_batch, val_label_batch = read_single_tfrecord(
        val_file, batch_size, img_shape)
    print("img shape", img_shape)
    with tf.name_scope('input'):
        input_images = tf.placeholder(tf.float32,
                                      shape=(batch_size, img_shape, img_shape,
                                             3),
                                      name='input_images')
        labels = tf.placeholder(tf.int32, shape=(batch_size), name='labels')
        learn_rate = tf.placeholder(tf.float32,
                                    shape=(None),
                                    name='learn_rate')
    with tf.name_scope('var'):
        global_step = tf.Variable(0, trainable=False, name='global_step')
    total_loss = make_parallel(build_network,
                               num_gpus,
                               input_images=input_images,
                               labels=labels)
    optimizer = tf.train.GradientDescentOptimizer(learn_rate)
    # colocate gradients with their ops so each GPU computes its own grads
    train_op = optimizer.minimize(tf.reduce_mean(total_loss),
                                  colocate_gradients_with_ops=True)
    summary_op = tf.summary.merge_all()

    with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter('./tmp/face_log', sess.graph)
        saver = tf.train.Saver()
        # coordinator + queue-runner threads feed the tfrecord readers
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        step = sess.run(global_step)
        epoch_idx = 0
        graph_step = 0
        item = './data/facescrub_train.list'
        # count the training list; context manager closes the handle
        # (the original leaked it)
        with open(item, 'r') as imagelist:
            file_len = len(imagelist.readlines())
        batch_num = np.ceil(file_len / batch_size)
        while epoch_idx <= epoch_num:
            step = 0
            ckpt_fg = 'True'
            ps_loss = 0.0
            pc_loss = 0.0
            acc_sum = 0.0
            while step < batch_num:
                train_img_batch, train_label_batch = sess.run(
                    [image_batch, label_batch])
                _, summary_str, Center_loss = sess.run(
                    [train_op, summary_op, total_loss],
                    feed_dict={
                        input_images: train_img_batch,
                        labels: train_label_batch,
                        learn_rate: lr
                    })
                step += 1
                graph_step += 1
                if step % 10 == 0:
                    writer.add_summary(summary_str, global_step=graph_step)
                pc_loss += Center_loss
                if step % 100 == 0:
                    # running average over the last 100 steps
                    print("****** Epoch {} Step {}: ***********".format(
                        str(epoch_idx), str(step)))
                    print("center loss: {}".format(pc_loss / 100.0))
                    print("softmax_loss: {}".format(ps_loss / 100.0))
                    print("train_acc: {}".format(acc_sum / 100.0))
                    print("*******************************")
                    # checkpoint at most once per epoch when loss < 0.1
                    if (Center_loss < 0.1 and ckpt_fg == 'True'):
                        print(
                            "******************************************************************************"
                        )
                        saver.save(sess, checkpoint_dir, global_step=epoch_idx)
                        ckpt_fg = 'False'
                    ps_loss = 0.0
                    pc_loss = 0.0
                    acc_sum = 0.0

            epoch_idx += 1

            if epoch_idx % 5 == 0:
                print(
                    "******************************************************************************"
                )
                saver.save(sess, checkpoint_dir, global_step=epoch_idx)

            # halve the learning rate every 5 epochs
            if epoch_idx % 5 == 0:
                lr = lr * 0.5

            if epoch_idx:
                # BUG FIX: fetch into fresh names — the original rebound
                # val_label_batch (the tensor) to a numpy array, breaking
                # sess.run on every subsequent epoch
                val_imgs, val_labels = sess.run(
                    [val_image_batch, val_label_batch])
                vali_acc = sess.run(total_loss,
                                    feed_dict={
                                        input_images: val_imgs,
                                        labels: val_labels
                                    })
                print(("epoch: {}, train_acc:{:.4f}, vali_acc:{:.4f}".format(
                    epoch_idx, Center_loss, vali_acc)))
        # stop queue runners before joining, otherwise join() blocks forever
        coord.request_stop()
        coord.join(threads)
        sess.close()
Beispiel #13
0
def train(net_factory, prefix, end_epoch, base_dir, base_lr=0.01, net='PNet'):
    """
    Train the MTCNN PNet.

    :param net_factory: graph builder; called as
        net_factory(images, label, bbox_target, landmark_target, training=True)
        and expected to return the ops
        (cls_loss, bbox_loss, landmark_loss, l2_loss, accuracy).
    :param prefix: checkpoint path prefix passed to saver.save
    :param end_epoch: total number of epochs to train
    :param base_dir: directory holding 'train_<net>_landmark.txt' and the
        matching tfrecord file
    :param base_lr: initial learning rate handed to train_model
    :param net: network name; only 'PNet' is supported by this function
    :raises ValueError: if net is not 'PNet'
    """
    # BUGFIX: the original overwrote the 'net' argument with a hard-coded
    # 'PNet', silently ignoring the caller's choice.  Honor the parameter
    # and fail fast for unsupported nets -- otherwise image_size and the
    # radio_* loss weights below would be unbound (NameError).
    if net != 'PNet':
        raise ValueError("Only net='PNet' is supported, got %r" % (net,))

    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    print(label_file)

    # Count the training samples; 'with' closes the file deterministically
    # (the original leaked the handle).
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total size of the dataset is: ", num)
    print(prefix)

    # PNet reads a single combined tfrecord with image/label/bbox/landmark.
    dataset_dir = os.path.join(base_dir, 'train_%s_landmark.tfrecord' % net)
    print('dataset dir is:', dataset_dir)
    image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
        dataset_dir, config.BATCH_SIZE, net)

    # PNet input resolution and multi-task loss weights (cls : bbox : landmark).
    image_size = 12
    radio_cls_loss = 1.0
    radio_bbox_loss = 0.5
    radio_landmark_loss = 0.5

    # Placeholders for one training batch.
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')
    # In-graph color-distortion augmentation.
    input_image = image_color_distort(input_image)
    cls_loss_op, bbox_loss_op, landmark_loss_op, l2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # Weighted multi-task loss plus L2 regularization.
    total_loss_op = (radio_cls_loss * cls_loss_op +
                     radio_bbox_loss * bbox_loss_op +
                     radio_landmark_loss * landmark_loss_op + l2_loss_op)

    train_op, lr_op = train_model(base_lr, total_loss_op, num)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        saver = tf.train.Saver()
        sess.run(init)

        # Queue-runner machinery feeding the input pipeline (TF1 style).
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        steps_in_epoch = 0
        MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
        epoch = 0

        try:
            for step in range(MAX_STEP):
                steps_in_epoch += 1
                image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                    [image_batch, label_batch, bbox_batch, landmark_batch])
                # Randomly mirror images (and their landmarks) as augmentation.
                image_batch_array, landmark_batch_array = random_flip_images(
                    image_batch_array, label_batch_array, landmark_batch_array)
                _, _ = sess.run(
                    [train_op, lr_op],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })

                if step % 50 == 0:
                    # Re-evaluate individual losses on the same batch for logging.
                    cls_loss, bbox_loss, landmark_loss, l2_loss, lr, acc = sess.run(
                        [
                            cls_loss_op, bbox_loss_op, landmark_loss_op,
                            l2_loss_op, lr_op, accuracy_op
                        ],
                        feed_dict={
                            input_image: image_batch_array,
                            label: label_batch_array,
                            bbox_target: bbox_batch_array,
                            landmark_target: landmark_batch_array
                        })

                    total_loss = (radio_cls_loss * cls_loss +
                                  radio_bbox_loss * bbox_loss +
                                  radio_landmark_loss * landmark_loss + l2_loss)
                    print(
                        "Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f "
                        % (step + 1, MAX_STEP, acc, cls_loss, bbox_loss,
                           landmark_loss, l2_loss, total_loss, lr))

                # Save one checkpoint per (approximate) epoch, counted by the
                # number of samples consumed since the last save.
                if steps_in_epoch * config.BATCH_SIZE > num:
                    steps_in_epoch = 0
                    epoch += 1
                    path_prefix = saver.save(sess, prefix, global_step=epoch)
                    print(path_prefix)
        except tf.errors.OutOfRangeError:
            # Input queue exhausted before MAX_STEP; finish cleanly.
            pass
        finally:
            coord.request_stop()
        coord.join(threads)
        sess.close()