Ejemplo n.º 1
0
def train(net_factory, prefix, end_epoch, base_dir, display=200, base_lr=0.01):
    """Train one MTCNN stage (PNet / RNet / ONet).

    :param net_factory: builder returning (cls_loss, bbox_loss, landmark_loss,
        L2_loss, accuracy) ops for the given input tensors
    :param prefix: checkpoint path prefix; its basename selects the net,
        e.g. '../data/MTCNN_model/PNet_landmark/PNet'
    :param end_epoch: total number of training epochs
    :param base_dir: directory holding the label file and tfrecord(s)
    :param display: print losses every `display` steps
    :param base_lr: initial learning rate
    """
    net = prefix.split('/')[-1]

    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    print(label_file)
    # Count training examples; context manager closes the file handle
    # (the original left it open).
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total datasets is: ", num)
    print(prefix)

    if net == 'PNet':  # PNet reads a single shuffled tfrecord
        dataset_dir = os.path.join(base_dir,
                                   'train_%s_landmark.tfrecord_shuffle' % net)
        print(dataset_dir)
        # One dequeue yields config.BATCH_SIZE images plus their labels.
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)

    else:  # RNet/ONet mix four tfrecords: pos, part, neg, landmark
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir,
                                    'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # Per-batch sampling ratios: 1/6 pos, 1/6 part, 1/6 landmark, 3/6 neg.
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)

    # Per-net input size and loss weights.
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:  # ONet: landmark loss gets full weight
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1.0
        image_size = 48

    # Placeholders fed with the dequeued numpy batches each step.
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')

    # Global step counter, fed explicitly from the Python loop below.
    global_ = tf.Variable(tf.constant(0), trainable=False)

    # Forward pass: classification + bbox/landmark regression losses.
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # The weighted total loss drives the optimizer and the lr schedule.
    train_op, lr_op = train_model(
        base_lr,
        radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op +
        radio_landmark_loss * landmark_loss_op + L2_loss_op, num, global_)
    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()
    # max_to_keep=0 keeps every checkpoint.
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)
    # TensorBoard summaries.
    tf.summary.scalar("cls_loss", cls_loss_op)
    tf.summary.scalar("bbox_loss", bbox_loss_op)
    tf.summary.scalar("landmark_loss", landmark_loss_op)
    tf.summary.scalar("cls_accuracy", accuracy_op)
    tf.summary.scalar("learning rate", lr_op)
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" % (net)
    # makedirs(exist_ok=True) avoids the exists()/mkdir race and also
    # creates missing parent directories.
    os.makedirs(logs_dir, exist_ok=True)

    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    coord = tf.train.Coordinator()
    # Start the input-pipeline threads that feed the tfrecord queues.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # Total steps across all epochs.
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    # Freeze the graph so accidental op creation inside the loop raises.
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            # Queue runners raise OutOfRangeError when the queue drains;
            # should_stop() lets us exit cooperatively.
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # Data augmentation: random horizontal flip.
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)
            # One optimization step on the weighted total loss.
            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array,
                    global_: step
                })
            # Periodically re-evaluate and print the individual losses.
            if (step + 1) % display == 0:
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array,
                        global_: step
                    })

                print(
                    "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f, lr: %4f"
                    % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                       landmark_loss, L2_loss, lr))
            # Save a checkpoint every two epochs.
            if i * config.BATCH_SIZE > num * 2:
                epoch = epoch + 1
                i = 0
                saver.save(sess, prefix, global_step=epoch * 2)
            if (step + 1) % display == 0:
                writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
Ejemplo n.º 2
0
def train(net_factory, prefix, end_epoch, base_dir,
          display=200, base_lr=0.01):
    """
    train PNet/RNet/ONet
    :param net_factory: builder returning the loss and accuracy ops
    :param prefix: checkpoint path prefix; basename selects the net
    :param end_epoch: total number of training epochs (e.g. 16)
    :param base_dir: directory holding the label file and tfrecord(s)
    :param display: print losses every `display` steps
    :param base_lr: initial learning rate
    :return: None
    """
    net = prefix.split('/')[-1]
    # Label file used only to count training examples.
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    print(label_file)
    # Context manager closes the handle (the original leaked it).
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total datasets is: ", num)
    print(prefix)

    # PNet reads a single shuffled tfrecord.
    if net == 'PNet':
        dataset_dir = os.path.join(base_dir, 'train_%s_landmark.tfrecord_shuffle' % net)
        print(dataset_dir)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(dataset_dir, config.BATCH_SIZE, net)

    # RNet/ONet mix four tfrecords: pos, part, neg, landmark.
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir, 'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # Per-batch sampling ratios: 1/6 pos, 1/6 part, 1/6 landmark, 3/6 neg.
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [pos_batch_size, part_batch_size, neg_batch_size, landmark_batch_size]
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(dataset_dirs, batch_sizes, net)

    # Per-net input size and loss weights.
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:  # ONet: landmark loss gets full weight
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1.0
        image_size = 48

    # Placeholders fed with the dequeued numpy batches each step.
    input_image = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, image_size, image_size, 3], name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4], name='bbox_target')
    landmark_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 10], name='landmark_target')
    # Forward pass: classification + bbox/landmark regression losses.
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(input_image, label, bbox_target, landmark_target, training=True)
    # The weighted total loss drives the optimizer and the lr schedule.
    train_op, lr_op = train_model(base_lr, radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + radio_landmark_loss*landmark_loss_op + L2_loss_op, num)
    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()
    # max_to_keep=0 keeps every checkpoint.
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)
    # TensorBoard summaries.
    tf.summary.scalar("cls_loss", cls_loss_op)
    tf.summary.scalar("bbox_loss", bbox_loss_op)
    tf.summary.scalar("landmark_loss", landmark_loss_op)
    tf.summary.scalar("cls_accuracy", accuracy_op)
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" % (net)
    # makedirs(exist_ok=True) avoids the exists()/mkdir race and creates
    # missing parents (os.mkdir would fail if ../logs did not exist).
    os.makedirs(logs_dir, exist_ok=True)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    coord = tf.train.Coordinator()
    # Start the input-pipeline threads that feed the tfrecord queues.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # Total steps across all epochs.
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    # Freeze the graph so accidental op creation inside the loop raises.
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run([image_batch, label_batch, bbox_batch, landmark_batch])
            # Data augmentation: random horizontal flip.
            image_batch_array, landmark_batch_array = random_flip_images(image_batch_array, label_batch_array, landmark_batch_array)
            # One optimization step on the weighted total loss.
            _, _, summary = sess.run([train_op, lr_op, summary_op], feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, landmark_target: landmark_batch_array})

            # Periodically re-evaluate and print the individual losses.
            if (step+1) % display == 0:
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run([cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, lr_op, accuracy_op],
                                                             feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, landmark_target: landmark_batch_array})
                print("%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f,lr:%f " % (
                datetime.now(), step+1, acc, cls_loss, bbox_loss, landmark_loss, L2_loss, lr))
            # Save a checkpoint every two epochs.
            if i * config.BATCH_SIZE > num*2:
                epoch = epoch + 1
                i = 0
                saver.save(sess, prefix, global_step=epoch*2)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
Ejemplo n.º 3
0
def train(net_factory, prefix, end_epoch, base_dir,
          display=100, base_lr=0.01, with_gesture=False):
    """
    train PNet/RNet/ONet (gesture variant)
    :param net_factory: a function defined in mtcnn_model.py
    :param prefix: model path; basename selects the net
    :param end_epoch: total number of training epochs
    :param base_dir: directory holding the label file and tfrecord(s)
    :param display: print losses every `display` steps
    :param base_lr: initial learning rate
    :param with_gesture: include the gesture loss in the total loss
    :return: None
    """
    net = prefix.split('/')[-1]
    # Label file used only to count training examples.
    label_file = os.path.join(base_dir, 'train_%s_gesture.txt' % net)
    print("--------------------------------------------------------")
    print(label_file)
    # Context manager closes the handle (the original leaked it).
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("number of training examples: ", num)

    print("Total size of the dataset is: ", num)
    print(prefix)
    print("--------------------------------------------------------")

    # PNet reads a single shuffled tfrecord.
    if net == 'PNet':
        dataset_dir = os.path.join(base_dir, 'train_%s_gesture.tfrecord_shuffle' % net)
        print('dataset dir is:', dataset_dir)
        image_batch, label_batch, bbox_batch, gesture_batch = read_single_tfrecord(dataset_dir, config.BATCH_SIZE, net)

    # RNet/ONet mix four tfrecords: pos, part, neg, gesture.
    else:
        pos_dir = os.path.join(base_dir, 'train_%s_pos_gesture.tfrecord_shuffle' % net)
        part_dir = os.path.join(base_dir, 'train_%s_part_gesture.tfrecord_shuffle' % net)
        neg_dir = os.path.join(base_dir, 'train_%s_neg_gesture.tfrecord_shuffle' % net)
        gesture_dir = os.path.join(base_dir, 'train_%s_gesture.tfrecord_shuffle' % net)
        dataset_dirs = [pos_dir, part_dir, neg_dir, gesture_dir]
        # Per-batch sampling ratios: 1/6 pos, 1/6 part, 1/6 gesture, 3/6 neg.
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        gesture_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        gesture_batch_size = int(np.ceil(config.BATCH_SIZE * gesture_radio))
        assert gesture_batch_size != 0, "Batch Size Error "
        # np.ceil can overshoot; shrink the gesture share so the four
        # sub-batches sum exactly to BATCH_SIZE.
        if (pos_batch_size + part_batch_size + neg_batch_size + gesture_batch_size) > config.BATCH_SIZE:
            gesture_batch_size = config.BATCH_SIZE - pos_batch_size - part_batch_size - neg_batch_size
        assert pos_batch_size + part_batch_size + neg_batch_size + gesture_batch_size == config.BATCH_SIZE, "num exceeded batchsize"
        batch_sizes = [pos_batch_size, part_batch_size, neg_batch_size, gesture_batch_size]
        image_batch, label_batch, bbox_batch, gesture_batch = read_multi_tfrecords(dataset_dirs, batch_sizes, net)

    # Per-net input size and loss weights.
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_gesture_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_gesture_loss = 0.5
    else:  # ONet: full gesture weight and more frequent logging
        image_size = 48
        display = 50
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_gesture_loss = 1

    # Placeholders fed with the dequeued numpy batches each step.
    input_image = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, image_size, image_size, 3], name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4], name='bbox_target')
    gesture_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 3], name='gesture_target')
    # Color-jitter augmentation applied on the graph side.
    input_image = image_color_distort(input_image)
    cls_loss_op, bbox_loss_op, gesture_loss_op, L2_loss_op, accuracy_op = net_factory(input_image, label, bbox_target, gesture_target, training=True)
    # Gesture loss is optional; L2 regularisation is always included.
    if with_gesture:
        total_loss_op = radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + radio_gesture_loss*gesture_loss_op + L2_loss_op
    else:
        total_loss_op = radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + L2_loss_op

    train_op, lr_op = train_model(base_lr,
                                  total_loss_op,
                                  num)
    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()

    # max_to_keep=1 keeps only the latest checkpoint.
    saver = tf.train.Saver(max_to_keep=1)
    sess.run(init)

    # TensorBoard summaries.
    tf.summary.scalar("cls_accuracy", accuracy_op)
    tf.summary.scalar("cls_loss", cls_loss_op)
    tf.summary.scalar("bbox_loss", bbox_loss_op)
    if with_gesture:
        tf.summary.scalar("gesture_loss", gesture_loss_op)
    tf.summary.scalar("total_loss", total_loss_op)
    tf.summary.scalar("learn_rate", lr_op)
    summary_op = tf.summary.merge_all()

    # Timestamped sub-directory so each run gets its own log dir.
    time = 'train-{}-{date:%Y-%m-%d_%H:%M:%S}'.format(net, date=datetime.now())
    print("-------------------------------------------------------------\n")
    print("the configuration is as follows:")
    # BUG FIX: the original printed base_dir where base_lr was labelled.
    print("base_lr: {}  lr_factor: {}  end_epoch: {}".format(base_lr, LR_FACTOR, end_epoch))
    print("the sub dir's name is: ", time)
    print("-------------------------------------------------------------\n")
    logs_dir = "../logs/%s/" % (net)
    logs_dir = logs_dir + time + "/"
    # makedirs(exist_ok=True) avoids the exists()/makedirs race.
    os.makedirs(logs_dir, exist_ok=True)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)
    coord = tf.train.Coordinator()
    # Start the input-pipeline threads that feed the tfrecord queues.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # Total steps across all epochs.
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    print("max_step: ", MAX_STEP)
    epoch = 0
    # Freeze the graph so accidental op creation inside the loop raises.
    sess.graph.finalize()

    try:

        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, gesture_batch_array = sess.run([image_batch, label_batch, bbox_batch, gesture_batch])
            # Data augmentation: random horizontal flip.
            image_batch_array, gesture_batch_array = random_flip_images(image_batch_array, label_batch_array, gesture_batch_array)
            # One optimization step on the selected total loss.
            _, _, summary = sess.run([train_op, lr_op, summary_op], feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, gesture_target: gesture_batch_array})

            # Periodically re-evaluate and print the individual losses.
            if (step+1) % display == 0:
                if with_gesture:
                    cls_loss, bbox_loss, gesture_loss, L2_loss, lr, acc = sess.run([cls_loss_op, bbox_loss_op, gesture_loss_op, L2_loss_op, lr_op, accuracy_op],
                                                                 feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, gesture_target: gesture_batch_array})
                    total_loss = radio_cls_loss*cls_loss + radio_bbox_loss*bbox_loss + radio_gesture_loss*gesture_loss + L2_loss
                    print("%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,gesture loss :%4f,regularisation loss: %4f, Total Loss: %4f ,lr:%f " % (
                    datetime.now(), step+1, MAX_STEP, acc, cls_loss, bbox_loss, gesture_loss, L2_loss, total_loss, lr))
                else:  # without gesture loss
                    cls_loss, bbox_loss, L2_loss, lr, acc = sess.run([cls_loss_op, bbox_loss_op, L2_loss_op, lr_op, accuracy_op],
                                                     feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, gesture_target: gesture_batch_array})
                    total_loss = radio_cls_loss*cls_loss + radio_bbox_loss*bbox_loss + L2_loss
                    print("%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,regularisation loss: %4f, Total Loss: %4f ,lr:%f " % (
                    datetime.now(), step+1, MAX_STEP, acc, cls_loss, bbox_loss, L2_loss, total_loss, lr))

            # Save a checkpoint every two epochs.
            if i * config.BATCH_SIZE > num*2:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, prefix, global_step=epoch*2)
                print('path prefix is :', path_prefix)
            writer.add_summary(summary, global_step=step)

    except tf.errors.OutOfRangeError:
        print("Finished!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
Ejemplo n.º 4
0
def train(net_factory, prefix, end_epoch, base_dir,
          display=200, base_lr=0.01):
    """
    train PNet/RNet/ONet
    :param net_factory: builder returning the loss and accuracy ops
    :param prefix: checkpoint path prefix; basename selects the net
    :param end_epoch: total number of training epochs (e.g. 16)
    :param base_dir: directory holding the label file and tfrecord(s)
    :param display: print losses every `display` steps
    :param base_lr: initial learning rate
    :return: None
    """
    net = prefix.split('/')[-1]
    # Label file used only to count training examples.
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    # BUG FIX: the original used Python 2 `print x` statements, which are
    # syntax errors under Python 3; parenthesised form works on both.
    print(label_file)
    # Context manager closes the handle (the original leaked it).
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total datasets is: ", num)
    print(prefix)

    # PNet reads a single shuffled tfrecord.
    if net == 'PNet':
        dataset_dir = os.path.join(base_dir, 'train_%s_landmark.tfrecord_shuffle' % net)
        print(dataset_dir)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(dataset_dir, config.BATCH_SIZE, net)

    # RNet/ONet mix four tfrecords: pos, part, neg, landmark.
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir, 'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # Per-batch sampling ratios: 1/6 pos, 1/6 part, 1/6 landmark, 3/6 neg.
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [pos_batch_size, part_batch_size, neg_batch_size, landmark_batch_size]
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(dataset_dirs, batch_sizes, net)

    # Per-net input size and loss weights.
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:  # ONet: landmark loss gets full weight
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1.0
        image_size = 48

    # Placeholders fed with the dequeued numpy batches each step.
    input_image = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, image_size, image_size, 3], name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 4], name='bbox_target')
    landmark_target = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE, 10], name='landmark_target')
    # Forward pass: classification + bbox/landmark regression losses.
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(input_image, label, bbox_target, landmark_target, training=True)
    # The weighted total loss drives the optimizer and the lr schedule.
    train_op, lr_op = train_model(base_lr, radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + radio_landmark_loss*landmark_loss_op + L2_loss_op, num)
    # init
    init = tf.global_variables_initializer()
    sess = tf.Session()
    # max_to_keep=0 keeps every checkpoint.
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)
    # TensorBoard summaries.
    tf.summary.scalar("cls_loss", cls_loss_op)
    tf.summary.scalar("bbox_loss", bbox_loss_op)
    tf.summary.scalar("landmark_loss", landmark_loss_op)
    tf.summary.scalar("cls_accuracy", accuracy_op)
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" % (net)
    # makedirs(exist_ok=True) avoids the exists()/mkdir race and creates
    # missing parents (os.mkdir would fail if ../logs did not exist).
    os.makedirs(logs_dir, exist_ok=True)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    coord = tf.train.Coordinator()
    # Start the input-pipeline threads that feed the tfrecord queues.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # Total steps across all epochs.
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    # Freeze the graph so accidental op creation inside the loop raises.
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run([image_batch, label_batch, bbox_batch, landmark_batch])
            # Data augmentation: random horizontal flip.
            image_batch_array, landmark_batch_array = random_flip_images(image_batch_array, label_batch_array, landmark_batch_array)
            # One optimization step on the weighted total loss.
            _, _, summary = sess.run([train_op, lr_op, summary_op], feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, landmark_target: landmark_batch_array})

            # Periodically re-evaluate and print the individual losses.
            if (step+1) % display == 0:
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run([cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, lr_op, accuracy_op],
                                                             feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, landmark_target: landmark_batch_array})
                print("%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f,lr:%f " % (
                datetime.now(), step+1, acc, cls_loss, bbox_loss, landmark_loss, L2_loss, lr))
            # Save a checkpoint every two epochs.
            if i * config.BATCH_SIZE > num*2:
                epoch = epoch + 1
                i = 0
                saver.save(sess, prefix, global_step=epoch*2)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
Ejemplo n.º 5
0
def train(net_factory, prefix, end_epoch, base_dir, display=200, base_lr=0.01):
    """
    Resume training of PNet/RNet/ONet from a previously saved checkpoint.

    :param net_factory: network builder; called as
        net_factory(image, label, bbox, landmark, training=True) and expected
        to return (cls_loss, bbox_loss, landmark_loss, L2_loss, accuracy) ops
    :param prefix: checkpoint path prefix; its basename ('PNet'/'RNet'/'ONet')
        selects the data pipeline and the loss weights
    :param end_epoch: epoch index at which training stops
    :param base_dir: directory holding the label file and tfrecord files
    :param display: log evaluation metrics every `display` steps
    :param base_lr: base learning rate handed to train_model
    :return: None (checkpoints, summaries and eval logs are side effects)
    """
    # NOTE(review): this variant hard-codes a warm start from epoch 16 — the
    # checkpoint '<prefix>-16' must already exist (see saver.restore below).
    before = 16
    net = prefix.split('/')[-1]
    #label file
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    #label_file = os.path.join(base_dir,'landmark_12_few.txt')
    print(label_file)
    # Count training examples. Use a context manager so the handle is always
    # closed (the previous code leaked the open file object).
    with open(label_file, 'r') as f:
        num = len(f.readlines())
    print("Total size of the dataset is: ", num)
    print(prefix)

    # PNet reads a single shuffled tfrecord ...
    if net == 'PNet':
        #dataset_dir = os.path.join(base_dir,'train_%s_ALL.tfrecord_shuffle' % net)
        dataset_dir = os.path.join(base_dir,
                                   'train_%s_landmark.tfrecord_shuffle' % net)
        print('dataset dir is:', dataset_dir)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)

    # ... while RNet/ONet mix four tfrecords (pos/part/neg/landmark) at a
    # fixed 1:1:3:1 ratio within each batch.
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir,
                                    'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        #print('batch_size is:', batch_sizes)
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)

    # Per-net input size and loss-term weights ("radio" == ratio in the
    # original codebase). ONet weights the landmark loss more heavily.
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1
        image_size = 48

    #define placeholder
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')
    # Color distortion is applied in-graph before the network.
    input_image = image_color_distort(input_image)
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
        input_image, label, bbox_target, landmark_target, training=True)
    # Weighted sum of the three task losses plus weight decay.
    total_loss_op = radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op + radio_landmark_loss * landmark_loss_op + L2_loss_op
    train_op, lr_op = train_model(base_lr, total_loss_op, num)
    # init
    #init = tf.global_variables_initializer()
    sess = tf.Session()

    #save model
    saver = tf.train.Saver(max_to_keep=0)
    #sess.run(init)
    # Warm start: restore all variables from the epoch-`before` checkpoint
    # instead of running a global initializer.
    saver.restore(sess, prefix + '-%d' % before)
    #visualize some variables
    tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  #bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  #landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  #cls_acc
    tf.summary.scalar(
        "total_loss", total_loss_op
    )  #cls_loss, bbox loss, landmark loss and L2 loss add together
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" % (net)
    # makedirs(exist_ok=True) replaces the racy exists()+mkdir pair and also
    # creates missing parent directories.
    os.makedirs(logs_dir, exist_ok=True)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    projector_config = projector.ProjectorConfig()
    projector.visualize_embeddings(writer, projector_config)
    #begin
    coord = tf.train.Coordinator()
    #begin enqueue thread
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # NOTE(review): before_step is a hard-coded resume offset that should
    # correspond to epoch `before` — confirm against the saved checkpoint.
    before_step = 67800
    #total steps remaining until end_epoch
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * (end_epoch - before)
    epoch = 16
    # Freeze the graph so accidental op creation inside the loop raises.
    sess.graph.finalize()
    try:

        for step in range(before_step, before_step + MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # Random horizontal flip as data augmentation (landmarks are
            # flipped consistently with the images).
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)

            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })

            if (step + 1) % display == 0:
                # Re-evaluate the losses on the current batch for logging.
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })

                total_loss = radio_cls_loss * cls_loss + radio_bbox_loss * bbox_loss + radio_landmark_loss * landmark_loss + L2_loss
                eval_result = "%s : Step: %d/%d, accuracy: %3f, cls loss: %4f, bbox loss: %4f,Landmark loss :%4f,L2 loss: %4f, Total Loss: %4f ,lr:%f " % (
                    datetime.now(), step + 1, MAX_STEP + before_step, acc,
                    cls_loss, bbox_loss, landmark_loss, L2_loss, total_loss,
                    lr)
                print(eval_result + '\n')
                # Append the same line to the persistent eval log
                # (EVALRESULTFILE is a module-level path).
                with open(EVALRESULTFILE, 'a') as f:
                    f.write(eval_result + '\n')

            # Save a checkpoint once per epoch (i counts batches since the
            # last save; one epoch ~= num / BATCH_SIZE batches).
            if i * config.BATCH_SIZE > num:
                epoch = epoch + 1
                i = 0
                path_prefix = saver.save(sess, prefix, global_step=epoch)
                print('path prefix is :', path_prefix)
                with open(EVALRESULTFILE, 'a') as f:
                    f.write('path prefix is :' + path_prefix + '\n')
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
Ejemplo n.º 6
0
def train(net_factory,
          prefix,
          load_epoch,
          end_epoch,
          base_dir,
          display=200,
          base_lr=0.01,
          gpu_ctx='/device:GPU:0'):
    """
    Train PNet/RNet/ONet, optionally resuming from a saved checkpoint.

    :param net_factory: network builder returning the loss/accuracy ops
    :param prefix: checkpoint path prefix; basename selects the net
        ('PNet'/'RNet'/'ONet')
    :param load_epoch: if truthy, restore variables from '<prefix>-<load_epoch>'
        before training; otherwise run a global initializer
    :param end_epoch: number of epochs used to compute MAX_STEP
    :param base_dir: directory with the label file and tfrecords
    :param display: print metrics every `display` steps
    :param base_lr: base learning rate handed to train_model
    :param gpu_ctx: device string (currently unused — the tf.device call
        below is commented out)
    :return: None

    NOTE(review): behavior branches on a module-level flag `train_face`
    not defined in this function — presumably True for face+landmark
    training and False for a classification/bbox-only variant; confirm
    where it is set.
    """
    net = prefix.split('/')[-1]
    #label file
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    #label_file = os.path.join(base_dir,'landmark_12_few.txt')
    #print label_file
    # NOTE(review): this file handle is never closed (leak); only the line
    # count is needed.
    f = open(label_file, 'r')
    num = len(f.readlines())
    print("Total datasets is: ", num)
    print("saved prefix: ", prefix)

    # PNet reads a single shuffled tfrecord ...
    if net == 'PNet':
        #dataset_dir = os.path.join(base_dir,'train_%s_ALL.tfrecord_shuffle' % net)
        dataset_dir = os.path.join(base_dir,
                                   'train_%s_landmark.tfrecord_shuffle' % net)
        print("dataset saved dir: ", dataset_dir)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)

    # ... while RNet/ONet mix several tfrecords; the landmark record and
    # landmark batch are only used when train_face is set.
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        if train_face:
            landmark_dir = os.path.join(base_dir,
                                        'landmark_landmark.tfrecord_shuffle')
        else:
            landmark_dir = None
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # Per-class batch composition ("radio" == ratio): 1:1:3:1 with
        # landmarks, or 2/3 pos + 1/6 part + remainder neg without.
        if train_face:
            pos_radio = 1.0 / 6
            part_radio = 1.0 / 6
            landmark_radio = 1.0 / 6
            neg_radio = 3.0 / 6
        else:
            pos_radio = 2.0 / 3
            part_radio = 1.0 / 6
            landmark_radio = 0
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        if train_face:
            neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
            assert neg_batch_size != 0, "Batch Size Error "
            #landmark_batch_size = int(np.ceil(config.BATCH_SIZE*landmark_radio))
            # Landmark slots take whatever remains of the batch so the four
            # parts always sum to exactly BATCH_SIZE.
            landmark_batch_size = int(config.BATCH_SIZE - pos_batch_size -
                                      part_batch_size - neg_batch_size)
            assert landmark_batch_size != 0, "Batch Size Error "
            batch_sizes = [
                pos_batch_size, part_batch_size, neg_batch_size,
                landmark_batch_size
            ]
            image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
                dataset_dirs, batch_sizes, net)
        else:
            # landmark_batch_size is a placeholder value here; negatives
            # fill the rest of the batch.
            landmark_batch_size = 1
            neg_batch_size = int(config.BATCH_SIZE - pos_batch_size -
                                 part_batch_size)
            assert neg_batch_size != 0, "Batch Size Error "
            batch_sizes = [
                pos_batch_size, part_batch_size, neg_batch_size,
                landmark_batch_size
            ]
            image_batch, label_batch, bbox_batch = read_multi_tfrecords(
                dataset_dirs, batch_sizes, net)

    # Per-net input size and loss-term weights.
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        #radio_cls_loss = 1.0;radio_bbox_loss = 0.5;radio_landmark_loss = 0.5
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1.0
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1.0
        image_size = 48

    #define placeholder
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='landmark_target')
    #class,regression
    #with tf.device(gpu_ctx):
    # Without train_face the factory is called without landmark_target —
    # net_factory must support both signatures.
    if train_face:
        cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
            input_image, label, bbox_target, landmark_target, training=True)
    else:
        cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = net_factory(
            input_image, label, bbox_target, training=True)
    # Weighted sum of task losses plus weight decay drives the optimizer.
    train_op, lr_op = train_model(
        base_lr,
        radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op +
        radio_landmark_loss * landmark_loss_op + L2_loss_op, num)
    # init
    #init = tf.global_variables_initializer()
    #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
    #tf_config.gpu_options = gpu_options
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf_config.log_device_placement = False
    sess = tf.Session(config=tf_config)
    #sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    #save model
    saver = tf.train.Saver(max_to_keep=0)
    #sess.run(init)
    # Either warm-start from '<prefix>-<load_epoch>' or initialize fresh.
    if load_epoch:
        #check whether the dictionary is valid
        model_path = "%s-%s" % (prefix, str(load_epoch))
        model_dict = '/'.join(model_path.split('/')[:-1])
        ckpt = tf.train.get_checkpoint_state(model_dict)
        print("restore model path:", model_path)
        # NOTE(review): readstate is computed but the assert using it is
        # commented out — an invalid path only surfaces in saver.restore.
        readstate = ckpt and ckpt.model_checkpoint_path
        #assert  readstate, "the params dictionary is not valid"
        saver.restore(sess, model_path)
        print("restore models' param")
    else:
        init = tf.global_variables_initializer()
        sess.run(init)
        print("init models using gloable: init")

    #visualize some variables
    tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
    tf.summary.scalar("bbox_loss", bbox_loss_op)  #bbox_loss
    tf.summary.scalar("landmark_loss", landmark_loss_op)  #landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  #cls_acc
    summary_op = tf.summary.merge_all()
    logs_dir = "../logs/%s" % (net)
    if os.path.exists(logs_dir) == False:
        os.mkdir(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    #begin
    coord = tf.train.Coordinator()
    #begin enqueue thread
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    #total steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    # Freeze the graph so accidental op creation inside the loop raises.
    sess.graph.finalize()
    save_acc = 0
    L2_loss = 0
    # Plain-text training record next to the checkpoints.
    model_dict = '/'.join(prefix.split('/')[:-1])
    log_r_file = os.path.join(model_dict, "train_record.txt")
    print("model record is ", log_r_file)
    record_file_out = open(log_r_file, 'w')
    #record_file_out = open("train_record.txt",'w')
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            # PNet (or non-face training) has no landmark batch to fetch.
            if train_face:
                if net == 'PNet':
                    image_batch_array, label_batch_array, bbox_batch_array = sess.run(
                        [image_batch, label_batch, bbox_batch])
                else:
                    image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                        [image_batch, label_batch, bbox_batch, landmark_batch])
            else:
                image_batch_array, label_batch_array, bbox_batch_array = sess.run(
                    [image_batch, label_batch, bbox_batch])
            #print("shape:  ",i, np.shape(image_batch_array),np.shape(label_batch_array),np.shape(bbox_batch_array))
            # Random horizontal flip augmentation (only when landmarks exist,
            # since they must be flipped alongside the images).
            if train_face and not (net == 'PNet'):
                image_batch_array, landmark_batch_array = random_flip_images(
                    image_batch_array, label_batch_array, landmark_batch_array)
            '''
            print image_batch_array.shape
            print label_batch_array.shape
            print bbox_batch_array.shape
            print landmark_batch_array.shape
            print label_batch_array[0]
            print bbox_batch_array[0]
            print landmark_batch_array[0]
            '''
            if train_face and not (net == 'PNet'):
                _, _, summary = sess.run(
                    [train_op, lr_op, summary_op],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })
            else:
                _, _, summary = sess.run(
                    [train_op, lr_op, summary_op],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array
                    })

            if (step + 1) % display == 0:
                #acc = accuracy(cls_pred, labels_batch)
                if train_face and not (net == 'PNet'):
                    cls_loss, bbox_loss,landmark_loss,L2_loss,lr,acc = sess.run([cls_loss_op, bbox_loss_op,landmark_loss_op,L2_loss_op,lr_op,accuracy_op],\
                                                             feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array, landmark_target: landmark_batch_array})
                    print(
                        "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f,lr:%f "
                        % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                           landmark_loss, L2_loss, lr))
                else:
                    cls_loss, bbox_loss,L2_loss,lr,acc = sess.run([cls_loss_op, bbox_loss_op,L2_loss_op,lr_op,accuracy_op],\
                                                             feed_dict={input_image: image_batch_array, label: label_batch_array, bbox_target: bbox_batch_array})
                    print(
                        "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, L2 loss: %4f,lr:%f "
                        % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                           L2_loss, lr))
            # Checkpoint every ~10 epochs' worth of batches.
            # NOTE(review): acc/cls_loss/bbox_loss/lr below are only bound
            # after the first `display` step — if this branch fires first it
            # raises NameError; confirm display < num*10/BATCH_SIZE.
            if i * config.BATCH_SIZE > num * 10:
                epoch = epoch + 1
                i = 0
                #if save_acc < L2_loss:
                saver.save(sess, prefix, global_step=epoch * 100)
                save_acc = L2_loss
                if train_face and not (net == 'PNet'):
                    print(
                        "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f,lr:%f "
                        % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                           landmark_loss, L2_loss, lr))
                    record_file_out.write(
                        "%s : epoch: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark_loss: %4f,lr:%f \n"
                        % (datetime.now(), epoch * 100, acc, cls_loss,
                           bbox_loss, landmark_loss, lr))
                else:
                    print(
                        "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, L2 loss: %4f,lr:%f "
                        % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                           L2_loss, lr))
                    record_file_out.write(
                        "%s : epoch: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, L2_loss: %4f,lr:%f \n"
                        % (datetime.now(), epoch * 100, acc, cls_loss,
                           bbox_loss, L2_loss, lr))
                print("model saved over ", save_acc)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("Over!!!")
    finally:
        coord.request_stop()
        writer.close()
    coord.join(threads)
    record_file_out.close()
    sess.close()
Ejemplo n.º 7
0
def train(net_factory, prefix, end_epoch, display=200, base_lr=0.01):
    """
    Train a binary (pos/neg) classifier net, validating on a held-out
    tfrecord pair and checkpointing once per epoch.

    :param net_factory: network builder; returns (cls_loss, L2_loss,
        cls_prob) ops
    :param prefix: checkpoint path prefix for saver.save
    :param end_epoch: number of epochs used to compute MAX_STEP
    :param display: run a validation pass every `display` steps
    :param base_lr: base learning rate handed to train_model
    :return: None

    NOTE(review): `config` here is rebound from the singleton below and
    shadows any module-level `config`; all paths/sizes come from it.
    """
    # `l` accumulates validation filenames seen so far, used to count how
    # many new eval images appear each epoch; [1] is a sentinel first value.
    mylist = [1]
    l = np.array(mylist)
    config = singleton.configuration._instance.config
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = config.train_gpu
    num = config.num

    pos_dir = os.path.join(config.train_pos_record)
    neg_dir = os.path.join(config.train_neg_record)

    pos_dir_val = os.path.join(config.val_pos_record)
    neg_dir_val = os.path.join(config.val_neg_record)

    dataset_dirs = [
        pos_dir, neg_dir
    ]  #array containing the directories of the different tfrecord files
    dataset_dirs_val = [pos_dir_val, neg_dir_val]
    pos_radio = config.pos_radio
    neg_radio = config.neg_radio

    pos_batch_size = int(
        np.ceil(config.BATCH_SIZE *
                pos_radio))  #specifying how many positives in the batch
    assert pos_batch_size != 0, "Batch Size Error "

    #part_batch_size = int(np.ceil(config.BATCH_SIZE*part_radio))
    #assert part_batch_size != 0,"Batch Size Error "
    #
    neg_batch_size = int(
        np.floor(config.BATCH_SIZE *
                 neg_radio))  #specifying how many negatives in the batch
    assert neg_batch_size != 0, "Batch Size Error "

    #landmark_batch_size = int(np.ceil(config.BATCH_SIZE*landmark_radio))    #specifying how many landmarks in the batch
    #assert landmark_batch_size != 0,"Batch Size Error "

    batch_sizes = [
        pos_batch_size, neg_batch_size
    ]  #array of the distribution of pos,neg, landmarks in the batch

    # batch_size number of images,labels,bbox,and landmarks from different record files (pos,neg,landmarks). Distribution of image weights is preserved
    image_batch, label_batch, filename_batch = read_tfrecord_v2.read_multi_tfrecords(
        dataset_dirs, batch_sizes)

    # Fixed validation batch: 100 positives + 100 negatives per fetch.
    image_batch_val, label_batch_val, filename_batch_val = read_tfrecord_v2.read_multi_tfrecords(
        dataset_dirs_val, [100, 100])

    radio_cls_loss = config.radio_cls_loss

    image_size = config.image_size

    #define placeholders
    input_image = tf.placeholder(
        tf.float32,
        shape=[None, image_size, image_size, config.input_channels],
        name='input_image')
    # One-hot two-class label: [neg, pos].
    label = tf.placeholder(tf.float32, shape=[None, 2], name='label')

    #class,regression
    # get initial losses
    #cls_loss_op,L2_loss_op,accuracy_op,cls_prob_op = net_factory(input_image, label,training=True)
    cls_loss_op, L2_loss_op, cls_prob_op = net_factory(input_image,
                                                       label,
                                                       training=True)
    # Optimize weighted classification loss plus weight decay.
    train_op, lr_op = train_model(base_lr,
                                  radio_cls_loss * cls_loss_op + L2_loss_op,
                                  num)
    # init
    # NOTE(review): `init` and `total_parameters` are computed but unused;
    # initialization actually happens via sess.run below.
    init = tf.global_variables_initializer()

    total_parameters = 0
    config1 = tf.ConfigProto()
    config1.gpu_options.allow_growth = True

    with tf.Session(config=config1) as sess:
        sess.run(tf.global_variables_initializer())
        #save model
        #saver = tf.train.Saver(tf.global_variables())
        saver = tf.train.Saver(max_to_keep=1000000)
        #sess.run(init)
        #visualize some variables
        tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
        #tf.summary.scalar("cls_accuracy",accuracy_op)#cls_acc
        summary_op = tf.summary.merge_all()
        # NOTE(review): hard-coded machine-specific log directory.
        logs_dir = "/home/wassimea/Desktop/wassimea/work/train_models/mn/logs"
        if os.path.exists(logs_dir) == False:
            os.mkdir(logs_dir)
        writer = tf.summary.FileWriter(logs_dir, sess.graph)
        #begin
        coord = tf.train.Coordinator()
        #begin enqueue thread
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        i = 0
        #total steps
        MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
        epoch = 0
        # Freeze the graph so accidental op creation inside the loop raises.
        sess.graph.finalize()
        try:
            for step in range(MAX_STEP):
                # Per-step try: environment errors are printed and the step
                # is skipped rather than aborting the whole run.
                try:
                    i = i + 1
                    if coord.should_stop():
                        print("coord must stop")
                        break
                    image_batch_array, label_batch_array, filename_batch_array = sess.run(
                        [image_batch, label_batch, filename_batch])
                    # Shuffle the fetched batch so pos/neg examples are
                    # interleaved (the reader returns them grouped).
                    p = np.random.permutation(len(filename_batch_array))
                    image_batch_array = image_batch_array[p]
                    label_batch_array = label_batch_array[p]
                    filename_batch_array = filename_batch_array[p]
                    #random flip
                    #image_batch_array,landmark_batch_array = random_flip_images(image_batch_array,label_batch_array,landmark_batch_array)
                    '''
                    print image_batch_array.shape
                    print label_batch_array.shape
                    print bbox_batch_array.shape
                    print landmark_batch_array.shape
                    print label_batch_array[0]
                    print bbox_batch_array[0]
                    print landmark_batch_array[0]
                    '''
                    _, _, summary = sess.run([train_op, lr_op, summary_op],
                                             feed_dict={
                                                 input_image:
                                                 image_batch_array,
                                                 label: label_batch_array
                                             })
                    # NOTE(review): `z` is unused.
                    z = (step + 1) % display
                    if (step + 1) % display == 0:

                        # Validation pass on a fresh 200-example batch.
                        image_batch_array_val, label_batch_array_val, filename_batch_array_val = sess.run(
                            [
                                image_batch_val, label_batch_val,
                                filename_batch_val
                            ])
                        #acc = accuracy(cls_pred, labels_batch)
                        #han
                        cls_loss, L2_loss, lr, cls_prob = sess.run(
                            [cls_loss_op, L2_loss_op, lr_op, cls_prob_op],
                            feed_dict={
                                input_image: image_batch_array_val,
                                label: label_batch_array_val
                            })
                        print(
                            "%s : Step: %d, cls loss: %4f,L2 loss: %4f,lr:%f "
                            %
                            (datetime.now(), step + 1, cls_loss, L2_loss, lr))
                    # End-of-epoch bookkeeping + checkpoint.
                    # NOTE(review): cls_prob / *_val arrays are only bound
                    # after the first display step — a NameError if this
                    # branch fires first; confirm display < num/BATCH_SIZE.
                    if i * config.BATCH_SIZE > num:
                        if (len(l) == 1):
                            l = filename_batch_array_val
                        else:
                            # Count eval filenames not seen in the previous
                            # recorded batch (O(n^2) scan over 200 items).
                            new = 0
                            for new_arr_val in filename_batch_array_val:
                                contained = False
                                for old_arr_val in l:
                                    if new_arr_val == old_arr_val:
                                        contained = True
                                if (contained == False):
                                    new += 1
                            print("New eval images: ", new)
                            # NOTE(review): `x` is assigned but unused —
                            # likely meant to update `l` with the new batch.
                            x = filename_batch_array_val
                        # Per-class accuracy over the 200-sample validation
                        # batch (100 pos + 100 neg, hence the /100 below).
                        total_pos = 0
                        total_neg = 0
                        # NOTE(review): this inner loop clobbers the outer
                        # epoch counter `i`; harmless only because i is
                        # reset to 0 just after.
                        for i in range(200):
                            #pred = cls_prob[i][1]
                            posval = label_batch_array_val[i][1]
                            negval = label_batch_array_val[i][0]
                            ind = np.argmax(cls_prob[i])
                            if (ind == 1 and posval == 1.0):
                                total_pos += 1
                            if (ind == 0 and negval == 1.0):
                                total_neg += 1
                        epoch = epoch + 1
                        i = 0
                        posacc = total_pos / 100
                        negacc = total_neg / 100

                        # Manually-built summaries so the accuracies appear
                        # in TensorBoard without graph ops.
                        summpos = tf.Summary()
                        summpos.value.add(tag="posacc", simple_value=posacc)
                        writer.add_summary(summpos, global_step=epoch *
                                           2)  #act as global_step
                        summneg = tf.Summary()
                        summneg.value.add(tag="negacc", simple_value=negacc)
                        writer.add_summary(summneg, global_step=epoch *
                                           2)  #act as global_step
                        saver.save(sess, prefix, global_step=epoch * 2)
                        #tf.train.write_graph(sess.graph.as_graph_def(), prefix, 'tensorflowModel.pb', False)
                        print("Finished epocch ------- Posacc: " +
                              str(posacc) + "-----Negacc: " + str(negacc))
                    writer.add_summary(summary, global_step=step)
                except EnvironmentError as error:
                    print(error)
                    # NOTE(review): `y` is unused filler.
                    y = 1
        except tf.errors.OutOfRangeError:
            print("完成!!!")
        finally:
            coord.request_stop()
            print("Before writing")
            writer.close()
        coord.join(threads)
        sess.close()
Ejemplo n.º 8
0
def train(net_factory,
          prefix,
          end_epoch,
          base_dir,
          display=200,
          base_lr=0.01,
          COLOR_GRAY=0):  # COLOR_GRAY: 0 = 3-channel color input, 1 = 1-channel gray
    """
    Train one MTCNN stage (PNet / RNet / ONet) from pre-built tfrecords.

    Only the classification head is optimized here: the bbox / landmark
    loss terms are commented out below and net_factory is expected to
    return (cls_loss, L2_loss, accuracy) ops.

    :param net_factory: callable(input_image, label, bbox_target,
        landmark_target, training) -> (cls_loss_op, L2_loss_op, accuracy_op)
    :param prefix: checkpoint path prefix; its basename selects the stage
        ('PNet' / 'RNet' / 'ONet')
    :param end_epoch: number of epochs to run
    :param base_dir: directory holding the *.tfrecord_shuffle inputs
    :param display: print training stats every `display` steps
    :param base_lr: base learning rate passed to train_model()
    :param COLOR_GRAY: 0 = color records (3 channels), 1 = gray (1 channel)
    :return: None; checkpoints go under `prefix`, summaries under ./logs/<net>
    """
    # Stage name is the basename of the checkpoint prefix.
    net = prefix.split('/')[-1]
    #label file
    #label_file = os.path.join(base_dir,'train_%s_32_landmarkcolor.txt' % net)
    #label_file = os.path.join(base_dir,'landmark_12_few.txt')
    #print (label_file )
    #f = open(label_file, 'r')
    # NOTE(review): dataset size is hard-coded (label-file counting above is
    # disabled); it drives MAX_STEP and the per-epoch checkpointing below.
    num = 1600000  #531806  #950000   # 1500000
    #num = len(f.readlines())
    #f.close()
    print("Total datasets is: ", num)
    print(prefix)

    # --- input pipeline ---------------------------------------------------
    # PNet, merged-record mode: one tfrecord holding all example types.
    if net == 'PNet' and not config.SINGLEF:
        if COLOR_GRAY == 0:
            #dataset_dir = os.path.join(base_dir,'train_%s_ALL.tfrecord_shuffle' % net)
            dataset_dir = os.path.join(
                base_dir, 'train_%s_12_color.tfrecord_shuffle' % net)
        elif COLOR_GRAY == 1:
            dataset_dir = os.path.join(
                base_dir, 'train_%s_12_gray.tfrecord_shuffle' % net)
        print(dataset_dir)
        #pdb.set_trace()
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net, COLOR_GRAY)
    # PNet, split-record mode: separate pos/part/neg tfrecords mixed per batch.
    elif net == 'PNet' and config.SINGLEF:
        if COLOR_GRAY == 0:
            pos_dir = os.path.join(base_dir,
                                   'PNet_12_color_pos.tfrecord_shuffle')
            part_dir = os.path.join(base_dir,
                                    'PNet_12_color_part.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'PNet_12_color_neg.tfrecord_shuffle')
        elif COLOR_GRAY == 1:
            pos_dir = os.path.join(base_dir,
                                   'PNet_12_gray_pos.tfrecord_shuffle')
            part_dir = os.path.join(base_dir,
                                    'PNet_12_gray_part.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'PNet_12_gray_neg.tfrecord_shuffle')
        # No landmark records in this mode (landmarkflag=0 below).
        landmark_dir = None
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]

        # Batch composition: 1/5 pos, 1/5 part, 3/5 neg.
        # NOTE(review): landmark_radio uses /6 while the others use /5 —
        # inconsistent, but landmark records are disabled (landmarkflag=0),
        # so it only sizes an unused slot; confirm against read_multi_tfrecords.
        pos_radio = 1.0 / 5
        part_radio = 1.0 / 5
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 5
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        landmarkflag = 0
        partflag = 1
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net, landmarkflag, partflag, COLOR_GRAY)
    # RNet: pos/part/neg tfrecords, no landmark records.
    elif net == 'RNet':
        if COLOR_GRAY == 0:
            pos_dir = os.path.join(base_dir,
                                   'RNet_24_color_pos.tfrecord_shuffle')
            part_dir = os.path.join(base_dir,
                                    'RNet_24_color_part.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'RNet_24_color_neg.tfrecord_shuffle')
            #landmark_dir = os.path.join(base_dir,'RNet_24_color_landmark.tfrecord_shuffle')
        elif COLOR_GRAY == 1:
            pos_dir = os.path.join(base_dir,
                                   'RNet_24_gray_pos.tfrecord_shuffle')
            part_dir = os.path.join(base_dir,
                                    'RNet_24_gray_part.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'RNet_24_gray_neg.tfrecord_shuffle')
            #landmark_dir = os.path.join(base_dir,'RNet_24_gray_landmark.tfrecord_shuffle')
        landmark_dir = None
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # Same composition / caveat as the PNet split-record branch above.
        pos_radio = 1.0 / 5
        part_radio = 1.0 / 5
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 5
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        landmarkflag = 0  #  select landmarkflag
        partflag = 1
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net, landmarkflag, partflag, COLOR_GRAY)
    # ONet: pos/neg only (part and landmark records both disabled).
    elif net == 'ONet':
        if COLOR_GRAY == 0:
            pos_dir = os.path.join(base_dir,
                                   'ONet_48_color_pos.tfrecord_shuffle')
            part_dir = os.path.join(base_dir,
                                    'ONet_48_color_part.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'ONet_48_color_neg.tfrecord_shuffle')
            #landmark_dir = os.path.join(base_dir,'ONet_48_color_landmark.tfrecord_shuffle')
        elif COLOR_GRAY == 1:
            pos_dir = os.path.join(base_dir,
                                   'ONet_48_gray_pos.tfrecord_shuffle')
            #part_dir = os.path.join(base_dir,'ONet_48_gray_part.tfrecord_shuffle')
            neg_dir = os.path.join(base_dir,
                                   'ONet_48_gray_neg_single.tfrecord_shuffle')
            #landmark_dir = os.path.join(base_dir,'ONet_48_gray_landmark.tfrecord_shuffle')
        # part_dir from the color branch is discarded here: partflag=0 below.
        part_dir = None
        landmark_dir = None
        #part_dir = None
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # 50/50 pos/neg split; part/landmark ratios only size unused slots.
        pos_radio = 1.0 / 2
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 1.0 / 2
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        landmarkflag = 0
        partflag = 0
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net, landmarkflag, partflag, COLOR_GRAY)

    # --- per-stage input size and loss weights ----------------------------
    # (bbox/landmark weights are set but unused: only cls + L2 is trained.)
    #landmark_dir
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:
        radio_cls_loss = 1.0
        radio_bbox_loss = 0
        radio_landmark_loss = 0
        image_size = 48

    #define placeholder
    # Channel count follows COLOR_GRAY: 1 for gray, 3 for color.
    if COLOR_GRAY == 1:
        input_image = tf.placeholder(
            tf.float32,
            shape=[config.BATCH_SIZE, image_size, image_size, 1],
            name='input_image')
    else:
        input_image = tf.placeholder(
            tf.float32,
            shape=[config.BATCH_SIZE, image_size, image_size, 3],
            name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    # NOTE(review): landmark target is 4-wide here (not the usual 10 for five
    # 2-D points) — confirm against the tfrecord schema; it is fed but unused
    # by the loss below.
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 4],
                                     name='landmark_target')
    #class,regression
    print('class,regression+')
    '''
    cls_loss_op,bbox_loss_op,landmark_loss_op,L2_loss_op,accuracy_op = net_factory(
        input_image, label, bbox_target,landmark_target,training=True)
    '''
    # Classification-only variant of the network losses.
    cls_loss_op, L2_loss_op, accuracy_op = net_factory(input_image,
                                                       label,
                                                       bbox_target,
                                                       landmark_target,
                                                       training=True)
    #train,update learning rate(3 loss)
    # train_op, lr_op,global_step = train_model(base_lr, radio_cls_loss*cls_loss_op + radio_bbox_loss*bbox_loss_op + radio_landmark_loss*landmark_loss_op + L2_loss_op, num)
    # Optimize weighted cls loss + L2 regularization only.
    train_op, lr_op, global_step = train_model(
        base_lr, radio_cls_loss * cls_loss_op + L2_loss_op, num)
    # init
    init = tf.global_variables_initializer()
    ###gpu
    #configp = tf.ConfigProto()
    #configp.allow_soft_placement = True
    #configp.gpu_options.per_process_gpu_memory_fraction = 0.3
    #configp.gpu_options.allow_growth = True
    #sess = tf.Session(config =configp)

    # Cap GPU memory usage at 30% so other jobs can share the device.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.3)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                            log_device_placement=False))
    #save model
    # Only trainable variables are checkpointed; keep the 3 newest files.
    saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)
    #saver = tf.train.Saver(max_to_keep=3)
    # restore model
    # Hard-coded warm-start checkpoints per stage/color mode.
    if net == 'PNet':
        if COLOR_GRAY == 1:
            pretrained_model = './data/new_model/PNet_merge_gray'
        else:
            pretrained_model = './data/new_model/PNet_merge_color'
    elif net == 'RNet':
        if COLOR_GRAY == 1:
            pretrained_model = './data/new_model/RNet_cmerge_gray_gray'
        else:
            pretrained_model = './data/new_model/RNet_cmerge_gray_color'
    elif net == 'ONet':
        pretrained_model = './data/new_model/test1-1ONet_NIR_calib_A_gray'
    else:
        pretrained_model = None

    sess.run(init)
    print(sess.run(global_step))
    # Warm-start from the latest checkpoint if pretraining is enabled.
    if pretrained_model and config.PRETRAIN:
        print('Restoring pretrained model: %s' % pretrained_model)
        ckpt = tf.train.get_checkpoint_state(pretrained_model)
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('Not Pretrain \n')

    print(sess.run(global_step))
    #visualize some variables
    tf.summary.scalar("cls_loss", cls_loss_op)  #cls_loss
    #tf.summary.scalar("bbox_loss",bbox_loss_op)#bbox_loss
    # tf.summary.scalar("landmark_loss",landmark_loss_op)#landmark_loss
    tf.summary.scalar("cls_accuracy", accuracy_op)  #cls_acc
    summary_op = tf.summary.merge_all()
    logs_dir = "./logs/%s" % (net)
    # NOTE(review): os.mkdir raises if ./logs itself is missing; os.makedirs
    # with exist_ok would be safer.
    if os.path.exists(logs_dir) == False:
        os.mkdir(logs_dir)
    writer = tf.summary.FileWriter(logs_dir, sess.graph)
    #begin
    coord = tf.train.Coordinator()
    #begin enqueue thread
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    #total steps
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch

    epoch = 0
    # Freeze the graph: any accidental op creation in the loop now raises.
    sess.graph.finalize()
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            # Dequeue one mixed batch from the tfrecord reader threads.
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            #random flip
            #image_batch_array,landmark_batch_array = random_flip_images(image_batch_array,label_batch_array,landmark_batch_array)
            '''
            print image_batch_array.shape
            print label_batch_array.shape
            print bbox_batch_array.shape
            print landmark_batch_array.shape
            print label_batch_array[0]
            print bbox_batch_array[0]
            print landmark_batch_array[0]
            '''
            # One optimization step; bbox/landmark feeds are required by the
            # placeholders even though only the cls loss is trained.
            _, _, summary = sess.run(
                [train_op, lr_op, summary_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })

            if (step + 1) % display == 0:

                #acc = accuracy(cls_pred, labels_batch)
                # Re-evaluate losses on the same batch for logging.
                cls_loss, L2_loss, lr, acc = sess.run(
                    [cls_loss_op, L2_loss_op, lr_op, accuracy_op],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })
                print(
                    "%s : Step: %d, accuracy: %3f, cls loss: %4f, L2 loss: %4f,lr:%f "
                    % (datetime.now(), step + 1, acc, cls_loss, L2_loss, lr))

            #save every two epochs
            # i counts batches since the last save; one save per `num` samples.
            if i * config.BATCH_SIZE > num:
                epoch = epoch + 1
                print('save epoch%d' % epoch)
                i = 0
                saver.save(sess, prefix, global_step=epoch)
            writer.add_summary(summary, global_step=step)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        # Always stop the queue threads and flush summaries on exit.
        coord.request_stop()
        writer.close()
    coord.join(threads)
    sess.close()
# Ejemplo n.º 9
# 0
def train(net_factory, prefix, end_epoch, base_dir, display=200, base_lr=0.01):
    """
    Train one MTCNN stage (PNet / RNet / ONet) on cls + bbox + landmark losses.

    :param net_factory: callable(input_image, label, bbox_target,
        landmark_target, training) returning
        (cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op)
    :param prefix: checkpoint path prefix; its basename selects the stage
        ('PNet' / 'RNet' / 'ONet')
    :param end_epoch: number of epochs to run
    :param base_dir: directory with the label file and tfrecord inputs
    :param display: print training stats every `display` steps
    :param base_lr: base learning rate handed to train_model()
    :return: None; checkpoints are saved under `prefix` every two epochs
    """
    # Stage name is the basename of the checkpoint prefix.
    net = prefix.split('/')[-1]
    label_file = os.path.join(base_dir, 'train_%s_landmark.txt' % net)
    # Count training samples; the context manager guarantees the file
    # descriptor is closed (the previous open() call leaked it).
    with open(label_file, 'r') as f:
        num = sum(1 for _ in f)
    # PNet reads a single merged tfrecord holding all example types.
    if net == 'PNet':
        dataset_dir = os.path.join(base_dir,
                                   'train_%s_landmark.tfrecord_shuffle' % net)
        image_batch, label_batch, bbox_batch, landmark_batch = read_single_tfrecord(
            dataset_dir, config.BATCH_SIZE, net)
    # Other stages mix four tfrecords (pos / part / neg / landmark) per batch.
    else:
        pos_dir = os.path.join(base_dir, 'pos_landmark.tfrecord_shuffle')
        part_dir = os.path.join(base_dir, 'part_landmark.tfrecord_shuffle')
        neg_dir = os.path.join(base_dir, 'neg_landmark.tfrecord_shuffle')
        landmark_dir = os.path.join(base_dir,
                                    'landmark_landmark.tfrecord_shuffle')
        dataset_dirs = [pos_dir, part_dir, neg_dir, landmark_dir]
        # Batch composition: 1/6 pos, 1/6 part, 1/6 landmark, 3/6 neg.
        pos_radio = 1.0 / 6
        part_radio = 1.0 / 6
        landmark_radio = 1.0 / 6
        neg_radio = 3.0 / 6
        # Sanity-check each slice, matching the sibling train() variants:
        # a zero-sized slice would silently starve the batch of one class.
        pos_batch_size = int(np.ceil(config.BATCH_SIZE * pos_radio))
        assert pos_batch_size != 0, "Batch Size Error "
        part_batch_size = int(np.ceil(config.BATCH_SIZE * part_radio))
        assert part_batch_size != 0, "Batch Size Error "
        neg_batch_size = int(np.ceil(config.BATCH_SIZE * neg_radio))
        assert neg_batch_size != 0, "Batch Size Error "
        landmark_batch_size = int(np.ceil(config.BATCH_SIZE * landmark_radio))
        assert landmark_batch_size != 0, "Batch Size Error "
        batch_sizes = [
            pos_batch_size, part_batch_size, neg_batch_size,
            landmark_batch_size
        ]
        image_batch, label_batch, bbox_batch, landmark_batch = read_multi_tfrecords(
            dataset_dirs, batch_sizes, net)
    # Per-stage input size and loss weights (ONet emphasizes landmarks).
    if net == 'PNet':
        image_size = 12
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    elif net == 'RNet':
        image_size = 24
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 0.5
    else:
        image_size = 48
        radio_cls_loss = 1.0
        radio_bbox_loss = 0.5
        radio_landmark_loss = 1.0
    # Graph inputs; fed from the dequeued numpy batches in the loop below.
    input_image = tf.placeholder(
        tf.float32,
        shape=[config.BATCH_SIZE, image_size, image_size, 3],
        name='input_image')
    label = tf.placeholder(tf.float32, shape=[config.BATCH_SIZE], name='label')
    bbox_target = tf.placeholder(tf.float32,
                                 shape=[config.BATCH_SIZE, 4],
                                 name='bbox_target')
    # NOTE(review): node name keeps the historical typo 'labdmark_target' so
    # any tooling that feeds this placeholder by name keeps working.
    landmark_target = tf.placeholder(tf.float32,
                                     shape=[config.BATCH_SIZE, 10],
                                     name='labdmark_target')
    # Build the classification / regression losses and accuracy ops.
    cls_loss_op, bbox_loss_op, landmark_loss_op, L2_loss_op, accuracy_op = \
        net_factory(input_image, label, bbox_target,
        landmark_target, training=True)
    # Optimizer and learning-rate schedule over the weighted total loss.
    train_op, lr_op = train_model(
        base_lr,
        radio_cls_loss * cls_loss_op + radio_bbox_loss * bbox_loss_op +
        radio_landmark_loss * landmark_loss_op + L2_loss_op, num)
    init = tf.global_variables_initializer()
    sess = tf.Session()
    # max_to_keep=0 keeps every checkpoint ever written.
    saver = tf.train.Saver(max_to_keep=0)
    sess.run(init)
    # Coordinator plus queue-runner threads drive the tfrecord readers.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    i = 0
    # Total optimization steps across all epochs.
    MAX_STEP = int(num / config.BATCH_SIZE + 1) * end_epoch
    epoch = 0
    try:
        for step in range(MAX_STEP):
            i = i + 1
            if coord.should_stop():
                break
            # Dequeue one mixed batch from the reader threads.
            image_batch_array, label_batch_array, bbox_batch_array, landmark_batch_array = sess.run(
                [image_batch, label_batch, bbox_batch, landmark_batch])
            # Random horizontal flip, applied to images and landmarks together.
            image_batch_array, landmark_batch_array = random_flip_images(
                image_batch_array, label_batch_array, landmark_batch_array)
            _, _ = sess.run(
                [train_op, lr_op],
                feed_dict={
                    input_image: image_batch_array,
                    label: label_batch_array,
                    bbox_target: bbox_batch_array,
                    landmark_target: landmark_batch_array
                })

            if (step + 1) % display == 0:
                # Re-evaluate the losses on the same batch for logging.
                cls_loss, bbox_loss, landmark_loss, L2_loss, lr, acc = sess.run(
                    [
                        cls_loss_op, bbox_loss_op, landmark_loss_op,
                        L2_loss_op, lr_op, accuracy_op
                    ],
                    feed_dict={
                        input_image: image_batch_array,
                        label: label_batch_array,
                        bbox_target: bbox_batch_array,
                        landmark_target: landmark_batch_array
                    })
                print(
                    "%s : Step: %d, accuracy: %3f, cls loss: %4f, bbox loss: %4f, landmark loss: %4f,L2 loss: %4f,lr:%f "
                    % (datetime.now(), step + 1, acc, cls_loss, bbox_loss,
                       landmark_loss, L2_loss, lr))
            # Checkpoint every two epochs; i counts batches since last save.
            if i * config.BATCH_SIZE > num * 2:
                epoch = epoch + 1
                i = 0
                saver.save(sess, prefix, global_step=epoch * 2)
    except tf.errors.OutOfRangeError:
        print("完成!!!")
    finally:
        # Always stop the queue threads, even on error or interrupt.
        coord.request_stop()
    coord.join(threads)
    sess.close()