Example #1
0
def gen_PNet_tfrecords(
    bbox_anno_file,
    bbox_im_dir,
    save_dir,
    landmark_anno_file,
    landmark_im_dir,
    tfrecords_output_dir,
    debug=False):
    """Build the PNet training list file and TFRecord.

    Calls ``gen_PNet_bbox_data`` for positive/negative/part list files and
    ``GenLandmarkData`` for the landmark list file. It subsamples the bbox
    lines to a 3:1:1 neg:pos:part ratio, capped at 250000 pos/part lines,
    and appends every landmark line. The combined list is written to
    ``<tfrecords_output_dir>/train_PNet.txt`` and then passed to ``run`` to
    produce ``<tfrecords_output_dir>/train_PNet.tfrecord``.

    Args:
        bbox_anno_file: annotation file forwarded to ``gen_PNet_bbox_data``.
        bbox_im_dir: image directory forwarded to ``gen_PNet_bbox_data``.
        save_dir: directory forwarded to both data generators.
        landmark_anno_file: annotation file forwarded to ``GenLandmarkData``.
        landmark_im_dir: image directory forwarded to ``GenLandmarkData``.
        tfrecords_output_dir: output directory; created (with parents) if
            it does not exist.
        debug: forwarded to both data generators.
    """
    net = "PNet"

    # files = (pos_list_file, neg_list_file, part_list_file)
    files = gen_PNet_bbox_data(bbox_anno_file, bbox_im_dir, save_dir, debug=debug)
    _, _, landmark_list_file = GenLandmarkData(
        landmark_anno_file, landmark_im_dir, net, save_dir,
        argument=True, debug=debug)

    with open(files[0], 'r') as f:
        pos = f.readlines()
    with open(files[1], 'r') as f:
        neg = f.readlines()
    with open(files[2], 'r') as f:
        part = f.readlines()
    with open(landmark_list_file, 'r') as f:
        landmark = f.readlines()

    # Target ratio neg:pos:part = 3:1:1, with at most 250000 pos/part lines.
    # Clamping base_num to the available counts guarantees that each draw
    # below is no larger than its population (base_num * 3 <= len(neg)).
    # Sampling without replacement therefore always works and keeps every
    # selected line distinct, whereas replace=True duplicated some lines and
    # dropped others.
    base_num = min(len(neg) // 3, len(pos), len(part), 250000)
    print(len(neg), len(pos), len(part), base_num)
    neg_keep = np.random.choice(len(neg), size=base_num * 3, replace=False)
    pos_keep = np.random.choice(len(pos), size=base_num, replace=False)
    part_keep = np.random.choice(len(part), size=base_num, replace=False)
    print(len(neg_keep), len(pos_keep), len(part_keep))

    # Order in the list file: pos, neg, part, then all landmark lines.
    imagelist = ([pos[i] for i in pos_keep]
                 + [neg[i] for i in neg_keep]
                 + [part[i] for i in part_keep]
                 + landmark)

    if not os.path.exists(tfrecords_output_dir):
        os.makedirs(tfrecords_output_dir)

    with open(os.path.join(tfrecords_output_dir, 'train_{}.txt'.format(net)), 'w') as f:
        f.writelines(imagelist)

    tf_filename = os.path.join(tfrecords_output_dir, 'train_{}.tfrecord'.format(net))
    run(imagelist, tf_filename, shuffling=True)
Example #2
0
def gen_ONet_tfrecords(
    bbox_anno_file,
    bbox_im_dir,
    save_dir,
    landmark_anno_file,
    tfrecords_output_dir,
    model_path,
    debug=False):
    """Build the ONet training list file and one TFRecord per sample type.

    Calls ``gen_ONet_bbox_data`` for positive/negative/part list files and
    ``GenLandmarkData`` for the landmark list file. It writes every line,
    with no subsampling, to ``<tfrecords_output_dir>/train_ONet.txt``, then
    calls ``run`` once per sample type:
    ``pos_landmark.tfrecord``, ``neg_landmark.tfrecord``,
    ``part_landmark.tfrecord`` and ``landmark_landmark.tfrecord``.

    Args:
        bbox_anno_file: annotation file forwarded to ``gen_ONet_bbox_data``.
        bbox_im_dir: image directory forwarded to ``gen_ONet_bbox_data``.
        save_dir: directory forwarded to both data generators.
        landmark_anno_file: annotation file forwarded to ``GenLandmarkData``.
        tfrecords_output_dir: output directory; created (with parents) if
            it does not exist.
        model_path: forwarded to ``gen_ONet_bbox_data``.
        debug: forwarded to both data generators.
    """
    size = 48
    net = 'ONet'

    # bbox_files = (pos_list_file, neg_list_file, part_list_file)
    bbox_files = gen_ONet_bbox_data(bbox_anno_file, bbox_im_dir, save_dir, model_path, debug=debug)
    _, _, landmark_list_file = GenLandmarkData(landmark_anno_file, net, size, save_dir,
                                               argument=True, debug=debug)

    with open(bbox_files[0], 'r') as f:
        pos = f.readlines()
    with open(bbox_files[1], 'r') as f:
        neg = f.readlines()
    with open(bbox_files[2], 'r') as f:
        part = f.readlines()
    with open(landmark_list_file, 'r') as f:
        landmark = f.readlines()

    # makedirs (not mkdir) so a nested output path with missing parents works.
    if not os.path.exists(tfrecords_output_dir):
        os.makedirs(tfrecords_output_dir)

    # Write all data: pos, neg, part, then landmark.
    with open(os.path.join(tfrecords_output_dir, "train_{}.txt".format(net)), "w") as f:
        print(len(neg))
        print(len(pos))
        print(len(part))
        print(len(landmark))
        for lines in (pos, neg, part, landmark):
            f.writelines(lines)

    # Pair each sample list with its TFRecord name explicitly, so the data
    # and the file names cannot drift out of order.
    record_jobs = [
        (pos, 'pos_landmark.tfrecord'),
        (neg, 'neg_landmark.tfrecord'),
        (part, 'part_landmark.tfrecord'),
        (landmark, 'landmark_landmark.tfrecord'),
    ]
    for lines, record_name in record_jobs:
        run(lines, os.path.join(tfrecords_output_dir, record_name), shuffling=True)
Example #3
0
def gen_RNet_tfrecords(
    bbox_anno_file,
    bbox_im_dir,
    save_dir,
    landmark_anno_file,
    landmark_im_dir,
    tfrecords_output_dir,
    model_path,
    debug=False):
    """Generate RNet training data and hand the list files to ``write_tfrecords``.

    Calls ``gen_RNet_bbox_data`` for the bbox list files and
    ``GenLandmarkData`` for the landmark list file. It then passes the
    paths to ``write_tfrecords`` in the order pos, neg, part, landmark,
    together with the output directory and the net name.

    Args:
        bbox_anno_file: annotation file forwarded to ``gen_RNet_bbox_data``.
        bbox_im_dir: image directory forwarded to ``gen_RNet_bbox_data``.
        save_dir: directory forwarded to both data generators.
        landmark_anno_file: annotation file forwarded to ``GenLandmarkData``.
        landmark_im_dir: image directory forwarded to ``GenLandmarkData``.
        tfrecords_output_dir: output directory forwarded to ``write_tfrecords``.
        model_path: forwarded to ``gen_RNet_bbox_data``.
        debug: forwarded to both data generators.
    """
    net = "RNet"

    # bbox_files = (pos_list_file, neg_list_file, part_list_file)
    bbox_files = gen_RNet_bbox_data(bbox_anno_file, bbox_im_dir, save_dir, model_path, debug=debug)
    _, _, landmark_list_file = GenLandmarkData(landmark_anno_file, landmark_im_dir,
                                               net, save_dir, argument=True, debug=debug)
    # Build a new list rather than appending to the generator's return value.
    # This works whether it returns a list or a tuple, and leaves it unmodified.
    list_files = list(bbox_files) + [landmark_list_file]
    write_tfrecords(list_files, tfrecords_output_dir, net)