def gen_onet_data(data_dir,
                  anno_file,
                  pnet_model_file,
                  rnet_model_file,
                  prefix_path='',
                  use_cuda=True,
                  vis=False):
    """Run PNet and RNet over the annotated images and build ONet samples.

    Every image listed in ``anno_file`` is passed through the PNet and RNet
    stages of the detector. The per-image aligned boxes are pickled to
    ``config.TRAIN_DATA_DIR`` and that pickle is handed to
    ``get_onet_sample_data`` together with ``data_dir``, ``anno_file`` and
    ``prefix_path``.

    Args:
        data_dir: Forwarded unchanged to ``get_onet_sample_data``.
        anno_file: Annotation file loaded through ``ImageDB`` in "test" mode.
        pnet_model_file: Path of the PNet weights given to ``MtcnnDetector``.
        rnet_model_file: Path of the RNet weights given to ``MtcnnDetector``.
        prefix_path: Prefix passed to ``ImageDB`` and to
            ``get_onet_sample_data``.
        use_cuda: Forwarded to ``MtcnnDetector``. (It used to be ignored and
            CUDA was always requested.)
        vis: When true, show the detections of each image via
            ``vision.vis_face``.
    """
    mtcnn_detector = MtcnnDetector(p_model_path=pnet_model_file,
                                   r_model_path=rnet_model_file,
                                   o_model_path=None,
                                   min_face_size=12,
                                   use_cuda=use_cuda)

    imagedb = ImageDB(anno_file, mode="test", prefix_path=prefix_path)
    imdb = imagedb.load_imdb()
    image_reader = TestImageLoader(imdb, 1, False)

    all_boxes = []

    for batch_idx, im in enumerate(image_reader):
        if batch_idx % 100 == 0:
            print("%d images done" % batch_idx)
        t = time.time()
        # Detect with PNet first, then refine its aligned boxes with RNet.
        _, p_boxes_align = mtcnn_detector.detect_pnet(im=im)
        _, boxes_align = mtcnn_detector.detect_rnet(im=im,
                                                    dets=p_boxes_align)
        if boxes_align is None:
            # Keep one entry per image so indices stay aligned with imdb.
            all_boxes.append(np.array([]))
            continue
        if vis:
            vision.vis_face(im, boxes_align)

        t1 = time.time() - t
        print('time cost for image ', batch_idx, '/', image_reader.size, ': ',
              t1)
        all_boxes.append(boxes_align)

    save_path = config.TRAIN_DATA_DIR
    # makedirs also creates missing parent directories, unlike mkdir.
    if not os.path.isdir(save_path):
        os.makedirs(save_path)

    save_file = os.path.join(save_path,
                             "pnet_rnet_detections_%d.pkl" % int(time.time()))

    with open(save_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    get_onet_sample_data(data_dir, anno_file, save_file, prefix_path)
Esempio n. 2
0
def gen_rnet_data(data_dir,
                  anno_file,
                  pnet_model_file,
                  prefix_path='',
                  use_cuda=True,
                  vis=False):
    """Run a trained PNet over the annotated images and build RNet samples.

    Every image listed in ``anno_file`` is passed through the PNet stage of
    the detector. The per-image aligned boxes are pickled to
    ``config.TRAIN_DATA_DIR`` and that pickle is handed to
    ``get_rnet_sample_data`` together with ``data_dir``, ``anno_file`` and
    ``prefix_path``.

    Args:
        data_dir: Forwarded unchanged to ``get_rnet_sample_data``.
        anno_file: Annotation file loaded through ``ImageDB`` in "test" mode.
        pnet_model_file: Path of the PNet weights given to ``MtcnnDetector``.
        prefix_path: Prefix passed to ``ImageDB`` and to
            ``get_rnet_sample_data``.
        use_cuda: Forwarded to ``MtcnnDetector``. (It used to be ignored and
            CUDA was always requested.)
        vis: When true, show the detections of each image via
            ``vision.vis_face``.
    """
    # Only the PNet stage is loaded; its detections become RNet training data.
    mtcnn_detector = MtcnnDetector(
        p_model_path=pnet_model_file,
        r_model_path=None,
        o_model_path=None,
        min_face_size=12,
        use_cuda=use_cuda)

    # Images over which the RNet training set is generated.
    imagedb = ImageDB(anno_file, mode="test", prefix_path=prefix_path)
    imdb = imagedb.load_imdb()
    image_reader = TestImageLoader(imdb, 1, False)

    all_boxes = []

    # enumerate keeps the index advancing for images without detections too;
    # the manual counter previously stalled on that path.
    for batch_idx, im in enumerate(image_reader):
        if batch_idx % 100 == 0:
            print("%d images done" % batch_idx)
        t = time.time()
        _, boxes_align = mtcnn_detector.detect_pnet(im)
        if boxes_align is None:
            # Keep one entry per image so indices stay aligned with imdb.
            all_boxes.append(np.array([]))
            continue
        if vis:
            vision.vis_face(im, boxes_align)

        t1 = time.time() - t
        print('time cost for image {} / {} : {:.4f}'.format(
            batch_idx, image_reader.size, t1))
        all_boxes.append(boxes_align)

    save_path = config.TRAIN_DATA_DIR
    # makedirs also creates missing parent directories, unlike mkdir.
    if not os.path.isdir(save_path):
        os.makedirs(save_path)

    save_file = os.path.join(save_path,
                             "pnet_detections_%d.pkl" % int(time.time()))

    with open(save_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    get_rnet_sample_data(data_dir, anno_file, save_file, prefix_path)
Esempio n. 3
0
def gen_onet_data(data_dir, anno_file, pnet_model_file, rnet_model_file, prefix_path='', use_cuda=True, vis=False):
    """Run PNet and RNet over the annotated images and build ONet samples.

    The networks are created with ``create_mtcnn_net`` and wrapped in an
    ``MtcnnDetector``. Every image listed in ``anno_file`` is passed through
    the PNet and RNet stages; the per-image aligned boxes are pickled to
    ``config.MODLE_STORE_DIR`` and that pickle is handed to
    ``get_onet_sample_data`` together with ``data_dir``, ``anno_file`` and
    ``prefix_path``.

    Args:
        data_dir: Forwarded unchanged to ``get_onet_sample_data``.
        anno_file: Annotation file loaded through ``ImageDB`` in "test" mode.
        pnet_model_file: Path of the PNet weights for ``create_mtcnn_net``.
        rnet_model_file: Path of the RNet weights for ``create_mtcnn_net``.
        prefix_path: Prefix passed to ``ImageDB`` and to
            ``get_onet_sample_data``.
        use_cuda: Forwarded to ``create_mtcnn_net``.
        vis: When true, show the detections of each image via
            ``vision.vis_two``.
    """
    pnet, rnet, _ = create_mtcnn_net(p_model_path=pnet_model_file, r_model_path=rnet_model_file, use_cuda=use_cuda)
    mtcnn_detector = MtcnnDetector(pnet=pnet, rnet=rnet, min_face_size=12)

    imagedb = ImageDB(anno_file, mode="test", prefix_path=prefix_path)
    imdb = imagedb.load_imdb()
    image_reader = TestImageLoader(imdb, 1, False)

    all_boxes = []

    for batch_idx, im in enumerate(image_reader):
        if batch_idx % 100 == 0:
            print("%d images done" % batch_idx)
        t = time.time()
        # Detect with PNet first, then refine its aligned boxes with RNet.
        _, p_boxes_align = mtcnn_detector.detect_pnet(im=im)
        boxes, boxes_align = mtcnn_detector.detect_rnet(im=im, dets=p_boxes_align)
        if boxes_align is None:
            # Keep one entry per image so indices stay aligned with imdb.
            all_boxes.append(np.array([]))
            continue
        if vis:
            rgb_im = cv2.cvtColor(np.asarray(im), cv2.COLOR_BGR2RGB)
            vision.vis_two(rgb_im, boxes, boxes_align)

        t1 = time.time() - t
        # Single-argument print(): valid on Python 2 and 3, and yields the
        # same space-separated text the old print statement produced.
        print(' '.join(str(part) for part in (
            'time cost for image ', batch_idx, '/', image_reader.size, ': ', t1)))
        all_boxes.append(boxes_align)

    save_path = config.MODLE_STORE_DIR
    # makedirs also creates missing parent directories, unlike mkdir.
    if not os.path.isdir(save_path):
        os.makedirs(save_path)

    save_file = os.path.join(save_path, "rnet_detections_%d.pkl" % int(time.time()))

    # NOTE(review): cPickle exists only on Python 2; on Python 3 the file
    # needs ``import pickle as cPickle`` (or equivalent) at module level.
    with open(save_file, 'wb') as f:
        cPickle.dump(all_boxes, f, cPickle.HIGHEST_PROTOCOL)

    get_onet_sample_data(data_dir, anno_file, save_file, prefix_path)
Esempio n. 4
0
def train_p_net(annotation_file,
                model_store_path,
                end_epoch=50,
                frequent=200,
                base_lr=0.01,
                batch_size=256,
                use_cuda=True):
    """Train PNet on the samples listed in ``annotation_file``.

    Uses Adam with a ``MultiStepLR`` schedule (milestones 10/25/40,
    gamma 0.1). The loss is ``cls_loss * 1.0 + box_offset_loss * 0.5``.
    A checkpoint is written every 10 epochs, plus a final one after the last
    epoch.

    Args:
        annotation_file: Annotation file loaded through ``ImageDB``.
        model_store_path: Directory for checkpoints; created if missing.
        end_epoch: Number of epochs to train.
        frequent: Log accuracy and losses every ``frequent`` batches.
        base_lr: Initial learning rate for Adam.
        batch_size: Batch size given to ``TrainImageReader``.
        use_cuda: Passed to ``PNet``; when true the net and batches are moved
            to ``cuda:0`` if available, otherwise to the CPU.
    """
    if not os.path.exists(model_store_path):
        os.makedirs(model_store_path)

    # Network, loss and optimisation set-up.
    net = PNet(is_train=True, use_cuda=use_cuda)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if use_cuda:
        net.to(device)
    lossfn = LossFn()
    optimizer = torch.optim.Adam(net.parameters(), lr=base_lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[10, 25, 40],
                                                     gamma=0.1)

    # Training data, augmented with horizontally flipped copies.
    imagedb = ImageDB(annotation_file)
    gt_imdb = imagedb.load_imdb()
    gt_imdb = imagedb.append_flipped_images(gt_imdb)
    train_data = TrainImageReader(gt_imdb, 12, batch_size, shuffle=True)

    net.train()
    for cur_epoch in range(end_epoch):
        train_data.reset()  # shuffle the data for this epoch
        for batch_idx, (image, (gt_label, gt_bbox,
                                gt_landmark)) in enumerate(train_data):
            im_tensor = torch.stack([
                image_tools.convert_image_to_tensor(image[i, :, :, :])
                for i in range(image.shape[0])
            ])

            gt_label = torch.from_numpy(gt_label).float()
            gt_bbox = torch.from_numpy(gt_bbox).float()
            # gt_landmark is unpacked but unused: no landmark loss here.
            if use_cuda:
                im_tensor = im_tensor.to(device)
                gt_label = gt_label.to(device)
                gt_bbox = gt_bbox.to(device)

            cls_pred, box_offset_pred = net(im_tensor)
            cls_loss = lossfn.cls_loss(gt_label, cls_pred)
            box_offset_loss = lossfn.box_loss(gt_label, gt_bbox,
                                              box_offset_pred)
            all_loss = cls_loss * 1.0 + box_offset_loss * 0.5

            if batch_idx % frequent == 0:
                accuracy = compute_accuracy(cls_pred, gt_label)
                # Read the lr from the optimizer: scheduler.get_lr() is
                # deprecated and can misreport the value between step() calls.
                current_lr = optimizer.param_groups[0]['lr']
                print(
                    "[%s, Epoch: %d, Step: %d] accuracy: %.6f, all_loss: %.6f, cls_loss: %.6f, bbox_reg_loss: %.6f, lr: %.6f"
                    % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                       cur_epoch + 1, batch_idx, accuracy.item(),
                       all_loss.item(), cls_loss.item(),
                       box_offset_loss.item(), current_lr))

            optimizer.zero_grad()
            all_loss.backward()
            optimizer.step()

        # Step the schedule after the epoch's optimizer updates. Stepping it
        # first (as before) skips the first schedule value on PyTorch >= 1.1
        # and moves every milestone one epoch early.
        scheduler.step()

        # TODO: add validation set for trained model

        if (cur_epoch + 1) % 10 == 0:
            torch.save(
                net.state_dict(),
                os.path.join(model_store_path,
                             "pnet_model_epoch_%d.pt" % (cur_epoch + 1)))

    # NOTE(review): "nodel" looks like a typo for "model"; the name is kept
    # because other code may load this exact file name.
    torch.save(net.state_dict(),
               os.path.join(model_store_path, 'pnet_nodel_final.pt'))