コード例 #1
0
ファイル: test.py プロジェクト: 124451/eye_copen_close_other
def main():
    """Evaluate a trained 24x48 MixNet checkpoint and print its top-1 accuracy.

    Samples are read from the module-level ``img_path`` through ``mbhk_data``;
    the printed value is (correct predictions) / (number of samples).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_trans = Transforms.Compose([
        Transforms.ToTensor(),
        Transforms.Normalize((0.407, 0.405, 0.412), (0.087, 0.087, 0.087)),
    ])
    test_data = mbhk_data(img_path, transform=data_trans)
    test_data_loader = torch.utils.data.DataLoader(
        test_data, batch_size=128, shuffle=False, num_workers=4)

    # Model expects 24x48 input images.
    net = MixNet(input_size=(24, 48), num_classes=3)
    # map_location lets a GPU-saved checkpoint load on the CPU fallback
    # selected for `device` above.
    weight_dict = torch.load("weight/change_mix_data_0202/Mixnet_epoch_79.pth",
                             map_location=device)
    # Checkpoints saved from nn.DataParallel prefix every key with "module.";
    # strip it only when it is actually there.
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in weight_dict.items())
    net.load_state_dict(new_state_dict)
    net.to(device)
    net.eval()

    correct = 0
    val_num = len(test_data)
    with torch.no_grad():
        for img, label, _ in test_data_loader:
            outputs = net(img.to(device))
            result = torch.max(outputs, 1)[1]
            correct += (result == label.to(device)).sum().item()
    print("access:%.3f" % (correct / val_num))
コード例 #2
0
def draw_RR_curve():
    """Score the validation split with a trained 24x24 MixNet and plot it.

    Collects the softmax scores of every eye crop listed in ``mix_valid.txt``
    together with the matching ground-truth class id, then hands both arrays
    (labels shaped (n, 1), scores shaped (n, class_num)) to ``plot_RR``.
    """
    class_num = 3
    vaild_ttrans = Transforms.Compose([
        Transforms.Resize((24, 24)),
        Transforms.ToTensor(),
        Transforms.Normalize((0.45, 0.448, 0.455), (0.082, 0.082, 0.082)),
    ])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    vaild_txt = '/media/omnisky/D4T/JSH/faceFenlei/eye/mbhlk_hl_0128/mix_valid.txt'

    mixnet = MixNet(input_size=(24, 24), num_classes=class_num)
    # map_location keeps loading working on the CPU fallback chosen above.
    weight_dict = torch.load(
        "/media/omnisky/D4T/JSH/faceFenlei/Projects/hul_eye_class/weight/relabel_04_mix_SGD_mutillabel_24_24_20210302/Mixnet_epoch_49.pth",
        map_location=device)
    # Strip the "module." prefix nn.DataParallel adds, only when present.
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in weight_dict.items())
    mixnet.load_state_dict(new_state_dict)
    mixnet.to(device)
    mixnet.eval()

    vaild_data = mbhk_get_signal_eye(vaild_txt, vaild_ttrans)
    valid_data_loader = DataLoader(vaild_data, batch_size=128, shuffle=False,
                                   num_workers=12)
    score_list = []  # softmax scores, one row per scored eye crop
    label_list = []  # ground-truth class id for each row of score_list
    with torch.no_grad():
        for imgs, labels, _ in valid_data_loader:
            # NOTE(review): assumes each dataset item is a stack of crops
            # (N, C, H, W) sharing one scalar label — TODO confirm.
            for timg, label in zip(imgs, labels):
                result = torch.nn.functional.softmax(mixnet(timg.to(device)),
                                                     dim=1)
                score_list.extend(result.cpu().numpy())
                # One label per score row keeps both arrays row-aligned.
                label_list.extend([label.item()] * result.shape[0])
    label_list = np.array(label_list).reshape((-1, 1))
    score_list = np.array(score_list).reshape((-1, class_num))
    plot_RR(label_list, score_list, [str(c) for c in range(class_num)])
コード例 #3
0
def main():
    """Run a 24x24 MixNet over ``img_path`` and export misclassified samples.

    For each sample, the crops in its image stack are classified one by one;
    at the first crop whose prediction differs from the label, both files of
    the sample's path pair (``tpath[0]``, ``tpath[1]``) are copied into a
    folder named after the predicted class and a ``<path> <pred> <label>``
    line is written to a log. Samples whose first path contains "/imge" go
    under ``./error_class/mbhk_img``, all others under
    ``./error_class/change_img``. Class names in the logs come from the
    module-level ``label_idx`` mapping.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_trans = Transforms.Compose([
        Transforms.Resize((24, 24)),
        Transforms.ToTensor(),
        Transforms.Normalize((0.45, 0.448, 0.455), (0.082, 0.082, 0.082)),
    ])
    test_data = mbhk_get_signal_eye(img_path, transform=data_trans)

    # Model expects 24x24 input images.
    net = MixNet(input_size=(24, 24), num_classes=3)
    weight_dict = torch.load(
        "weight/mix_mbhk_change_signal_eye_24_24/Mixnet_epoch_59.pth",
        map_location=device)
    # Strip the "module." prefix nn.DataParallel adds, only when present.
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in weight_dict.items())
    net.load_state_dict(new_state_dict)
    net.to(device)
    net.eval()

    # Destination sub-folder for a misclassified sample, keyed by predicted id.
    pred_dirs = {0: "open_eye", 1: "close_eye", 2: "other"}
    class_img = ["close_eye", "open_eye", "other"]
    error_class_mbhk_img_path = "./error_class/mbhk_img"
    error_class_change_img_path = "./error_class/change_img"
    for root in (error_class_mbhk_img_path, error_class_change_img_path):
        for tp in class_img:
            os.makedirs(os.path.join(root, tp), exist_ok=True)

    # `with` guarantees both logs are flushed and closed even if a copy fails.
    with open(os.path.join(error_class_mbhk_img_path, "error_mbhk_log.txt"),
              'w') as error_lod_mbhk:
        with open(os.path.join(error_class_change_img_path,
                               "error_change_log.txt"), 'w') as error_lod_change:
            with torch.no_grad():
                for i in tqdm(range(len(test_data))):
                    timg, label, tpath = test_data[i]
                    for img in timg:
                        outputs = net(img.unsqueeze(0).to(device))
                        result = torch.max(outputs, 1)[1].item()
                        if result == label:
                            continue
                        if "/imge" in tpath[0]:
                            # mbhk data
                            out_root, log = error_class_mbhk_img_path, error_lod_mbhk
                        else:
                            out_root, log = error_class_change_img_path, error_lod_change
                        dst = os.path.join(out_root, pred_dirs[result])
                        shutil.copy(tpath[0], dst)
                        shutil.copy(tpath[1], dst)
                        log.write("{} {} {}\n".format(
                            tpath[0], label_idx[result], label_idx[label]))
                        # Export each sample at most once.
                        break
コード例 #4
0
# Loss function, optimizer and step LR schedule (x0.1 every 40 scheduler steps).
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
schedule = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)

# Iterations per epoch (the last batch may be partial) and total iterations.
epoch_size = math.ceil(len(train_data) / batchsize)
maxiter = epoch * epoch_size
epoch_count = 0

for iteration in range(maxiter):
    acc = 0.0
    if iteration % epoch_size == 0:
        # Epoch boundary: validate the finished epoch (if any), then reshuffle.
        if epoch_count > 0:
            schedule.step()
            model.eval()
            toal_loss = 0.0
            with torch.no_grad():
                for imgs, label in valid_data_loader:
                    test_result = model(imgs.cuda())
                    loss = loss_function(test_result, label.cuda())
                    result = torch.max(test_result, 1)[1]
                    acc += (result == label.to(device)).sum().item()
                    # Accumulate a Python float, not a device tensor.
                    toal_loss += loss.item()
                # CrossEntropyLoss returns a per-batch mean, so average over
                # batches; accuracy counts samples, so divide by dataset size.
                mean_loss = toal_loss / len(valid_data_loader)
                mean_acc = acc / len(vaild_data)
                writer.add_scalars("test_loss_acc",
                                   {"loss": mean_loss, "access": mean_acc},
                                   epoch_count)
                print("valid_loss:{},valid_access:{}".format(mean_loss, mean_acc))
            if epoch_count % 10 == 9:
                torch.save(model.state_dict(), "/media/omnisky/D4T/JSH/faceFenlei/Projects/hul_eye_class/weight/set_lookdown_as_open_signaleye_24_24_20210301/Mixnet_epoch_{}.pth".format(epoch_count))
                print("save weight success!!")
        train_data_loader = iter(DataLoader(dataset=train_data, batch_size=batchsize,
                                            shuffle=True, num_workers=12))
        epoch_count += 1
コード例 #5
0
def main():
    """Plot ROC curves of a trained 24x24 MixNet on the validation split.

    Draws one ROC curve per class plus the micro- and macro-averaged curves
    (with their AUCs in the legend), saves the figure to ``set113_roc.jpg``
    and shows it.
    """
    class_num = 3
    vaild_ttrans = Transforms.Compose([
        Transforms.Resize((24, 24)),
        Transforms.ToTensor(),
        Transforms.Normalize((0.45, 0.448, 0.455), (0.082, 0.082, 0.082)),
    ])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    vaild_txt = '/media/omnisky/D4T/JSH/faceFenlei/eye/mbhlk_hl_0128/mix_valid.txt'

    mixnet = MixNet(input_size=(24, 24), num_classes=class_num)
    weight_dict = torch.load(
        "/media/omnisky/D4T/JSH/faceFenlei/Projects/hul_eye_class/weight/relabel_04_mix_SGD_mutillabel_24_24_20210302/Mixnet_epoch_49.pth",
        map_location=device)
    # Strip the "module." prefix nn.DataParallel adds, only when present.
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in weight_dict.items())
    mixnet.load_state_dict(new_state_dict)
    mixnet.to(device)
    mixnet.eval()

    vaild_data = mbhk_get_signal_eye(vaild_txt, vaild_ttrans)
    valid_data_loader = DataLoader(vaild_data, batch_size=128, shuffle=False,
                                   num_workers=12)
    score_list = []  # softmax scores, one row per scored eye crop
    label_ids = []   # ground-truth class id for each row of score_list
    with torch.no_grad():
        for imgs, labels, _ in valid_data_loader:
            # NOTE(review): assumes each dataset item is a stack of crops
            # (N, C, H, W) sharing one scalar label — TODO confirm.
            for timg, label in zip(imgs, labels):
                result = torch.nn.functional.softmax(mixnet(timg.to(device)),
                                                     dim=1)
                score_list.extend(result.cpu().numpy())
                # One label per score row keeps both arrays row-aligned.
                label_ids.extend([label.item()] * result.shape[0])
    # One-hot labels: column i is the binary target for class i.
    tlabel_list = np.eye(class_num, dtype=np.int64)[np.asarray(label_ids, dtype=np.int64)]
    tscore_list = np.array(score_list).reshape((-1, class_num))

    # Per-class (one-vs-rest) fpr / tpr / AUC via sklearn.
    fpr_dict = dict()
    tpr_dict = dict()
    roc_auc_dict = dict()
    for i in range(class_num):
        fpr_dict[i], tpr_dict[i], _ = roc_curve(tlabel_list[:, i], tscore_list[:, i])
        roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i])

    # Micro-average: pool every (label, score) pair across classes.
    fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve(tlabel_list.ravel(), tscore_list.ravel())
    roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"])

    # Macro-average: interpolate every class curve on the union of their
    # false-positive rates, then average the true-positive rates.
    all_fpr = np.unique(np.concatenate([fpr_dict[i] for i in range(class_num)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(class_num):
        mean_tpr += np.interp(all_fpr, fpr_dict[i], tpr_dict[i])
    mean_tpr /= class_num
    fpr_dict["macro"] = all_fpr
    tpr_dict["macro"] = mean_tpr
    roc_auc_dict["macro"] = auc(fpr_dict['macro'], tpr_dict["macro"])

    plt.figure()
    lw = 2
    plt.plot(fpr_dict["micro"], tpr_dict["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc_dict["micro"]),
             color='deeppink', linestyle=':', linewidth=4)

    plt.plot(fpr_dict["macro"], tpr_dict["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc_dict["macro"]),
             color='navy', linestyle=':', linewidth=4)

    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for i, color in zip(range(class_num), colors):
        plt.plot(fpr_dict[i], tpr_dict[i], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'
                       ''.format(i, roc_auc_dict[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Some extension of Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")
    plt.savefig('set113_roc.jpg')
    plt.show()
コード例 #6
0
def main():
    """Train a 24x48 MixNet eye-state classifier (open / close / other).

    Trains for 90 epochs with ``DataParallel`` on GPUs 0 and 1, using Adam
    and a StepLR schedule. It logs the per-iteration training loss and the
    per-epoch validation loss/accuracy to TensorBoard. It saves the state
    dict to ``./weight/change_mix_data_0202`` after epochs 9, 19, ..., 89.
    """
    # TensorBoard writer
    writer = SummaryWriter("run/change_mix_iniput_24_48")
    train_txt_path = "/media/omnisky/D4T/JSH/faceFenlei/eye/mbhlk_hl_0128/mix_train.txt"
    valid_txt_path = "/media/omnisky/D4T/JSH/faceFenlei/eye/mbhlk_hl_0128/mix_valid.txt"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batch_size = 128
    max_epoch = 90

    # Preprocessing: augmentation is applied to the training split only.
    transforms_function = {
        'train':
        Transforms.Compose([
            Transforms.RandomVerticalFlip(p=0.5),
            Transforms.ColorJitter(brightness=0.5),
            Transforms.ToTensor(),
            Transforms.Normalize((0.407, 0.405, 0.412), (0.087, 0.087, 0.087)),
        ]),
        'test':
        Transforms.Compose([
            Transforms.ToTensor(),
            Transforms.Normalize((0.407, 0.405, 0.412), (0.087, 0.087, 0.087)),
        ])
    }
    train_data = mbhk_data(train_txt_path, transform=transforms_function['train'])
    valid_data = mbhk_data(valid_txt_path, transform=transforms_function['test'])
    test_data_loader = DataLoader(valid_data,
                                  batch_size=128,
                                  shuffle=False,
                                  num_workers=8)

    model = MixNet(input_size=(24, 48), num_classes=3)
    model.to(device)
    # Multi-GPU training
    model = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Scheduler is stepped once per epoch boundary, so lr drops x0.1 every 30.
    schedule = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    def validate(step):
        """Log mean validation loss (per batch) and accuracy (per sample)."""
        model.eval()
        total_loss = 0.0
        correct = 0
        with torch.no_grad():
            for timages, tlabels, _ in test_data_loader:
                test_result = model(timages.cuda())
                loss = loss_function(test_result, tlabels.cuda())
                result = torch.max(test_result, 1)[1]
                correct += (result == tlabels.to(device)).sum().item()
                # Accumulate a Python float, not a device tensor.
                total_loss += loss.item()
        writer.add_scalars(
            "test_loss_acc", {
                "loss": total_loss / len(test_data_loader),
                "access": correct / len(valid_data)
            }, step)

    # Iterations per epoch (the last batch may be partial) and in total.
    epoch_size = math.ceil(len(train_data) / batch_size)
    maxiter = max_epoch * epoch_size
    epoch = 0  # number of epochs started so far
    for iteration in range(maxiter):
        if iteration % epoch_size == 0:
            # Epoch boundary: validate the finished epoch (if any), reshuffle.
            if epoch > 0:
                schedule.step()
                validate(epoch)
                if epoch % 10 == 9:
                    torch.save(
                        model.state_dict(),
                        "./weight/change_mix_data_0202/Mixnet_epoch_{}.pth".
                        format(epoch))
                    print("save weight success!!")
            train_data_loader = iter(
                DataLoader(dataset=train_data,
                           batch_size=batch_size,
                           shuffle=True,
                           num_workers=12))
            epoch += 1

        model.train()
        load_t0 = time.time()
        images, labels, _ = next(train_data_loader)
        optimizer.zero_grad()
        # Tensor.cuda() returns a copy; it must be reassigned to take effect.
        images = images.cuda()
        outputs = model(images)
        loss = loss_function(outputs, labels.cuda())
        loss.backward()
        optimizer.step()

        load_t1 = time.time()
        batch_time = load_t1 - load_t0
        eta = int(batch_time * (maxiter - iteration))
        print(
            "Epoch:{}/{} || Epochiter:{}/{} || loss:{:.4f}||Batchtime:{:.4f}||ETA:{}"
            .format(epoch, max_epoch, (iteration % epoch_size) + 1, epoch_size,
                    loss.item(), batch_time,
                    str(datetime.timedelta(seconds=eta))))
        writer.add_scalar("loss", loss.item(), iteration)

    # The loop validates only at the start of the next epoch, so the final
    # epoch would otherwise never be evaluated.
    if epoch > 0:
        validate(epoch)
    writer.close()
def dete_picture():
    """Annotate eye-state predictions on every ``*.jpg`` in ``img_file``.

    For each image, MTCNN detects faces and 24 landmarks, and every detected
    eye box is drawn. For the first face, both eye crops are classified with
    a 24x24 MixNet. The predicted class/score and the per-class softmax
    scores of both eyes are drawn onto the frame. The annotated frame is
    written to ``img_save`` under the same file name.
    """
    eye_class_dict = {0: "open_eye", 1: "close_eye", 2: "other"}
    # BGR text colour per predicted class; anything else is drawn in red.
    eye_class_color = {0: (255, 0, 255), 1: (0, 255, 0)}
    point_nums = 24
    threshold = [0.6, 0.7, 0.7]
    data_trans = Transforms.Compose([
        Transforms.Resize((24, 24)),
        Transforms.ToTensor(),
        Transforms.Normalize((0.45, 0.448, 0.455), (0.082, 0.082, 0.082)),
    ])
    mixnet = MixNet(input_size=(24, 24), num_classes=3)
    weight_dict = torch.load(
        "/media/omnisky/D4T/JSH/faceFenlei/Projects/hul_eye_class/weight/relabel_mix_24_24_20210302/Mixnet_epoch_59.pth"
    )
    # Strip the "module." prefix nn.DataParallel adds, only when present.
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in weight_dict.items())
    mixnet.load_state_dict(new_state_dict)
    mixnet.to('cuda:0')
    mixnet.eval()

    pnet, rnet, onet = create_mtcnn_net(
        p_model_path=r'model_store/final/pnet_epoch_19.pt',
        r_model_path=r'model_store/final/rnet_epoch_7.pt',
        o_model_path=r'model_store/final/onet_epoch_92.pt',
        use_cuda=True)
    mtcnn_detector = MtcnnDetector(pnet=pnet,
                                   rnet=rnet,
                                   onet=onet,
                                   min_face_size=24,
                                   threshold=threshold)
    img_file = "/media/omnisky/D4T/JSH/faceFenlei/Projects/hul_eye_class/test_video/caiji_0123"
    img_save = "/media/omnisky/D4T/JSH/faceFenlei/Projects/hul_eye_class/result_video/relabel_img_result_adma_01"
    # cv2.imwrite returns False instead of raising when the folder is missing.
    os.makedirs(img_save, exist_ok=True)
    # glob already returns paths rooted at img_file; do not join them again.
    img_path = sorted(glob.glob(os.path.join(img_file, "*.jpg")))

    for img_p in tqdm(img_path):
        img_name = os.path.split(img_p)[-1]
        frame = cv2.imread(img_p)
        if frame is None:
            # cv2.imread returns None for unreadable files instead of raising.
            continue
        # Unannotated copy used for cropping.
        copy_frame = frame.copy()
        img_h, img_w = copy_frame.shape[:2]
        left_right_eye = []
        bboxs, landmarks, wearmask = mtcnn_detector.detect_face(frame, rgb=True)

        if landmarks is not None:
            for i in range(landmarks.shape[0]):
                landmarks_one = landmarks[i, :].reshape((point_nums, 2))
                left_eye = np.array(landmarks_one[[6, 8, 10, 11, 14], :])
                right_eye = np.array(landmarks_one[[7, 9, 12, 13, 15], :])
                # Tight box [xmin, ymin, xmax, ymax] around each eye's points.
                for eye in (left_eye, right_eye):
                    left_right_eye.append([
                        np.min(eye[:, 0]), np.min(eye[:, 1]),
                        np.max(eye[:, 0]), np.max(eye[:, 1])
                    ])
                for j in [*left_eye, *right_eye]:
                    cv2.circle(frame, (int(j[0]), int(j[1])), 2, (255, 0, 0), -1)

            crop_img = []
            for xmin, ymin, xmax, ymax in left_right_eye:
                w, h = xmax - xmin, ymax - ymin
                # Grow the box into a square of side (1 + 2k) * max(w, h).
                # Fixed margin k = 0.1 (the original note mentions a random
                # 0.05-0.15 range).
                k = 0.1
                ratio = h / w
                if ratio > 1:
                    ratio = ratio - 1
                    xmin -= (ratio / 2 * w + k * h)
                    ymin -= (k * h)
                    xmax += (ratio / 2 * w + k * h)
                    ymax += (k * h)
                else:
                    ratio = w / h - 1
                    xmin -= (k * w)
                    ymin -= (ratio / 2 * h + k * w)
                    xmax += (k * w)
                    ymax += (ratio / 2 * h + k * w)
                cv2.rectangle(frame, (int(xmin), int(ymin)),
                              (int(xmax), int(ymax)), (0, 255, 255), 2)
                # Clamp to the image: negative slice bounds would wrap around.
                x0, y0 = max(int(xmin), 0), max(int(ymin), 0)
                x1, y1 = min(int(xmax), img_w), min(int(ymax), img_h)
                crop_img.append(copy_frame[y0:y1, x0:x1])

            # Only the first face's two eyes are classified and displayed.
            if len(crop_img) < 2 or any(c.size == 0 for c in crop_img[:2]):
                cv2.imwrite(os.path.join(img_save, img_name), frame)
                continue

            result_buff = []  # [class id, max softmax score] per eye
            score_buff = []   # full softmax array (1, 3) per eye
            for crop in crop_img[:2]:
                compose_img = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
                img = data_trans(compose_img).unsqueeze(0)
                with torch.no_grad():
                    outputs = mixnet(img.to('cuda:0'))
                    spft_max = torch.nn.functional.softmax(outputs, dim=1)
                score_buff.append(spft_max.cpu().numpy())
                score, result = torch.max(spft_max, 1)
                result_buff.append([result.item(), score.item()])

            # Class/score text above each eye; the first (left) eye's label is
            # shifted 100 px further left.
            bias = 30
            eye_bias = 100
            for (t_result, t_score), eye_box in zip(result_buff, left_right_eye):
                eye_class = "{}:{:.2f}".format(
                    eye_class_dict.get(t_result, "other"), t_score)
                cv2.putText(frame, eye_class,
                            (int(eye_box[0]) - eye_bias, int(eye_box[1]) - bias),
                            cv2.FONT_HERSHEY_COMPLEX, 0.6,
                            eye_class_color.get(t_result, (0, 0, 255)),
                            thickness=2)
                eye_bias = 0

            # Per-class softmax scores of both eyes in the top-left corner.
            for side, x, scores in (("left", 10, score_buff[0][0]),
                                    ("right", 200, score_buff[1][0])):
                eye_open, eye_close, eye_other = scores
                for y, name, value in ((20, "open", eye_open),
                                       (40, "close", eye_close),
                                       (60, "other", eye_other)):
                    cv2.putText(frame, "{}_{}:{:.2f}".format(side, name, value),
                                (x, y), cv2.FONT_HERSHEY_COMPLEX, 0.6,
                                (20, 150, 0), thickness=2)
        cv2.imwrite(os.path.join(img_save, img_name), frame)
def show_with_camera():
    """Run the eye open/close/other classifier live on the default camera.

    Every captured frame goes through the MTCNN detector. For each detected
    face, two eye boxes are built from the 24-point landmarks:
    - left eye: indices 6, 8, 10, 11, 14;
    - right eye: indices 7, 9, 12, 13, 15.

    Each box is padded to a square with a margin, cropped from the untouched
    frame, converted to 3-channel grayscale and classified by MixNet.

    The overall verdict is "open_eye" if any eye is class 0, otherwise
    "close_eye" if any eye is class 1, otherwise "other". It is drawn on a
    grayscale copy of the frame. Every frame is shown in the "test" window
    and appended to ``save_path_root``. Press ``q`` to stop.

    The camera, the video writer and the window are released on exit, even
    when the loop stops early.
    """
    eye_class_dict = {0: "open_eye", 1: "close_eye", 2: "other"}
    # BGR text colour for each overall verdict.
    eye_class_color = {0: (255, 0, 255), 1: (0, 255, 0), 2: (0, 0, 255)}
    point_nums = 24
    threshold = [0.6, 0.7, 0.7]
    data_trans = Transforms.Compose([
        Transforms.Resize((24, 24)),
        Transforms.ToTensor(),
        Transforms.Normalize((0.45, 0.448, 0.455), (0.082, 0.082, 0.082)),
    ])
    mixnet = MixNet(input_size=(24, 24), num_classes=3)
    weight_dict = torch.load(
        "/media/omnisky/D4T/JSH/faceFenlei/Projects/hul_eye_class/weight/mix_mbhk_change_signal_eye_24_24/Mixnet_epoch_59.pth"
    )
    # Drop the first 7 characters of every key. Presumably this is the
    # "module." prefix nn.DataParallel adds when saving; confirm for new
    # checkpoints.
    new_state_dict = OrderedDict((k[7:], v) for k, v in weight_dict.items())
    mixnet.load_state_dict(new_state_dict)
    mixnet.to('cuda:0')
    mixnet.eval()

    pnet, rnet, onet = create_mtcnn_net(
        p_model_path=r'model_store/final/pnet_epoch_19.pt',
        r_model_path=r'model_store/final/rnet_epoch_7.pt',
        o_model_path=r'model_store/final/onet_epoch_92.pt',
        use_cuda=True)
    mtcnn_detector = MtcnnDetector(pnet=pnet,
                                   rnet=rnet,
                                   onet=onet,
                                   min_face_size=24,
                                   threshold=threshold)
    save_path_root = 'result_video/camera_test_20210301.avi'

    cap = cv2.VideoCapture(0)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    fps = cap.get(cv2.CAP_PROP_FPS)
    size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    out = cv2.VideoWriter(save_path_root, fourcc, fps, size)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("finish")
                break

            copy_frame = frame.copy()
            left_right_eye = []
            bboxs, landmarks, wearmask = mtcnn_detector.detect_face(frame,
                                                                    rgb=True)

            if landmarks is not None:
                for i in range(landmarks.shape[0]):
                    landmarks_one = landmarks[i, :].reshape((point_nums, 2))
                    left_eye = np.array(landmarks_one[[6, 8, 10, 11, 14], :])
                    right_eye = np.array(landmarks_one[[7, 9, 12, 13, 15], :])
                    # Axis-aligned bounding box [xmin, ymin, xmax, ymax] per eye.
                    for eye in (left_eye, right_eye):
                        left_right_eye.append([
                            np.min(eye[:, 0]), np.min(eye[:, 1]),
                            np.max(eye[:, 0]), np.max(eye[:, 1]),
                        ])
                    # Mark the landmarks used for both eyes.
                    for j in [*left_eye, *right_eye]:
                        cv2.circle(frame, (int(j[0]), int(j[1])), 2,
                                   (255, 0, 0), -1)

                crop_img = []
                for xmin, ymin, xmax, ymax in left_right_eye:
                    w, h = xmax - xmin, ymax - ymin
                    if w <= 0 or h <= 0:
                        # Degenerate box: the padding maths below divides by w or h.
                        crop_img.append(None)
                        continue
                    # Pad the shorter side so the box becomes square. Then add a
                    # fixed margin of k times the longer side on every edge.
                    # The original note suggests training used a random 0.05-0.15.
                    k = 0.1
                    ratio = h / w
                    if ratio > 1:
                        ratio = ratio - 1
                        xmin -= (ratio / 2 * w + k * h)
                        ymin -= (k * h)
                        xmax += (ratio / 2 * w + k * h)
                        ymax += (k * h)
                    else:
                        ratio = w / h - 1
                        xmin -= (k * w)
                        ymin -= (ratio / 2 * h + k * w)
                        xmax += (k * w)
                        ymax += (ratio / 2 * h + k * w)
                    cv2.rectangle(frame, (int(xmin), int(ymin)),
                                  (int(xmax), int(ymax)), (0, 255, 255), 2)
                    # Clamp to the frame. A negative slice bound would wrap around
                    # to the far edge and yield an empty or wrong crop.
                    x0, y0 = max(int(xmin), 0), max(int(ymin), 0)
                    x1, y1 = max(int(xmax), 0), max(int(ymax), 0)
                    temp_img = copy_frame[y0:y1, x0:x1]
                    crop_img.append(temp_img if temp_img.size else None)

                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
                crop_img = [crop for crop in crop_img if crop is not None]
                if len(crop_img) >= 2:
                    t_result = []
                    for crop in crop_img:
                        # The model is fed 3-channel grayscale crops.
                        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
                        gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
                        img = data_trans(Image.fromarray(gray)).unsqueeze(0)
                        with torch.no_grad():
                            outputs = mixnet(img.to('cuda:0'))
                            t_result.append(torch.max(outputs, 1)[1].item())
                    if 0 in t_result:
                        state = 0
                    elif 1 in t_result:
                        state = 1
                    else:
                        state = 2
                    # The text is anchored at the last processed eye box.
                    cv2.putText(frame, eye_class_dict[state],
                                (int(xmax), int(ymax) - 20),
                                cv2.FONT_HERSHEY_COMPLEX, 1.0,
                                eye_class_color[state], thickness=2)

            # Show and poll the keyboard on every frame, including frames
            # without a face, so 'q' always works.
            cv2.imshow("test", frame)
            if cv2.waitKey(10) == ord('q'):
                print("get out")
                break
            out.write(frame)
    finally:
        # Finalize the output file and free the camera on every exit path.
        cap.release()
        out.release()
        cv2.destroyAllWindows()
def dete_signal_video():
    """Classify eye state frame by frame on a single test video.

    Every frame goes through the MTCNN detector. For each detected face, two
    eye boxes are built from the 24-point landmarks:
    - left eye: indices 6, 8, 10, 11, 14;
    - right eye: indices 7, 9, 12, 13, 15.

    Each box is padded to a square with a margin, and the crop is converted
    to RGB. The first two boxes (the first face's eyes) are classified by
    MixNet with a softmax over the 3 classes.

    The following are drawn on each frame:
    - per eye: the raw box width and the predicted class with its score;
    - both eyes' softmax scores;
    - an overall "other" / "close" / "open" verdict from width/score
      heuristics.

    Frames are written to ``save_path_root``. Frames without a face, or
    without two usable eye crops, are written without classification.
    """
    point_nums = 24
    threshold = [0.6, 0.7, 0.7]
    data_trans = Transforms.Compose([
        Transforms.Resize((24, 24)),
        Transforms.ToTensor(),
        Transforms.Normalize((0.45, 0.448, 0.455), (0.082, 0.082, 0.082)),
    ])
    mixnet = MixNet(input_size=(24, 24), num_classes=3)
    weight_dict = torch.load(
        "/media/omnisky/D4T/JSH/faceFenlei/Projects/hul_eye_class/weight/relabel_04_mix_SGD_mutillabel_24_24_20210302/Mixnet_epoch_49.pth"
    )
    # Drop the first 7 characters of every key. Presumably this is the
    # "module." prefix nn.DataParallel adds when saving; confirm for new
    # checkpoints.
    new_state_dict = OrderedDict((k[7:], v) for k, v in weight_dict.items())
    mixnet.load_state_dict(new_state_dict)
    mixnet.to('cuda:0')
    mixnet.eval()

    pnet, rnet, onet = create_mtcnn_net(
        p_model_path=r'model_store/final/pnet_epoch_19.pt',
        r_model_path=r'model_store/final/rnet_epoch_7.pt',
        o_model_path=r'model_store/final/onet_epoch_92.pt',
        use_cuda=True)
    mtcnn_detector = MtcnnDetector(pnet=pnet,
                                   rnet=rnet,
                                   onet=onet,
                                   min_face_size=24,
                                   threshold=threshold)
    videos_root_path = 'test_video/hhh/02_65_6504_0_be4ba2aeac264ed992aae74c15b91b18.mp4'
    save_path_root = 'result_video/debug_test.avi'

    font = cv2.FONT_HERSHEY_COMPLEX
    score_color = (20, 150, 0)
    info_color = (255, 0, 0)
    # Label prefix and BGR colour for the per-eye text. Any other id is "other".
    eye_label = {0: ("open_eye", (255, 0, 255)), 1: ("close_eye", (0, 255, 0))}
    other_label = ("other", (0, 0, 255))

    cap = cv2.VideoCapture(videos_root_path)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    fps = cap.get(cv2.CAP_PROP_FPS)
    size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    out = cv2.VideoWriter(save_path_root, fourcc, fps, size)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("finish")
                break

            copy_frame = frame.copy()
            bboxs, landmarks, wearmask = mtcnn_detector.detect_face(frame,
                                                                    rgb=True)
            if landmarks is None:
                # No face: nothing to classify, keep the raw frame.
                out.write(frame)
                continue

            # Axis-aligned bounding box [xmin, ymin, xmax, ymax] per eye,
            # left eye then right eye for every face.
            left_right_eye = []
            for i in range(landmarks.shape[0]):
                landmarks_one = landmarks[i, :].reshape((point_nums, 2))
                for eye_idx in ([6, 8, 10, 11, 14], [7, 9, 12, 13, 15]):
                    eye = np.array(landmarks_one[eye_idx, :])
                    left_right_eye.append([
                        np.min(eye[:, 0]), np.min(eye[:, 1]),
                        np.max(eye[:, 0]), np.max(eye[:, 1]),
                    ])

            crop_img = []
            eye_wild_buf = []  # raw (unpadded) width of every eye box
            for xmin, ymin, xmax, ymax in left_right_eye:
                w, h = xmax - xmin, ymax - ymin
                eye_wild_buf.append(w)
                if w <= 0 or h <= 0:
                    # Degenerate box: the padding maths below divides by w or h.
                    crop_img.append(None)
                    continue
                # Pad the shorter side so the box becomes square. Then add a
                # fixed margin of k times the longer side on every edge.
                k = 0.1
                ratio = h / w
                if ratio > 1:
                    ratio = ratio - 1
                    xmin -= (ratio / 2 * w + k * h)
                    ymin -= (k * h)
                    xmax += (ratio / 2 * w + k * h)
                    ymax += (k * h)
                else:
                    ratio = w / h - 1
                    xmin -= (k * w)
                    ymin -= (ratio / 2 * h + k * w)
                    xmax += (k * w)
                    ymax += (ratio / 2 * h + k * w)
                cv2.rectangle(frame, (int(xmin), int(ymin)),
                              (int(xmax), int(ymax)), (0, 255, 255), 1)
                # Clamp to the frame. A negative slice bound would wrap around
                # to the far edge and yield an empty or wrong crop.
                x0, y0 = max(int(xmin), 0), max(int(ymin), 0)
                x1, y1 = max(int(xmax), 0), max(int(ymax), 0)
                temp_img = copy_frame[y0:y1, x0:x1]
                crop_img.append(temp_img if temp_img.size else None)

            if len(crop_img) < 2 or crop_img[0] is None or crop_img[1] is None:
                out.write(frame)
                continue

            # Only the first face's two eyes are used below.
            result_buff = []  # [class id, max softmax score] per eye
            score_buff = []   # softmax array of shape (1, 3) per eye
            for crop in crop_img[:2]:
                rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
                img = data_trans(Image.fromarray(rgb)).unsqueeze(0)
                with torch.no_grad():
                    outputs = mixnet(img.to('cuda:0'))
                    probs = torch.nn.functional.softmax(outputs, dim=1)
                    score_buff.append(probs.cpu().numpy())
                    score, result = torch.max(probs, 1)
                    result_buff.append([result.item(), score.item()])

            # Per-eye width and class label. The first eye's text is shifted
            # 100 px to the left.
            eye_bias = 100
            for i in range(2):
                t_result, t_score = result_buff[i]
                x, y = int(left_right_eye[i][0]), int(left_right_eye[i][1])
                cv2.putText(frame, "w:{}".format(int(eye_wild_buf[i])),
                            (x - eye_bias, y - 50), font, 0.6, (255, 0, 255),
                            thickness=2)
                label, color = eye_label.get(t_result, other_label)
                cv2.putText(frame, "{}:{:.2f}".format(label, t_score),
                            (x - eye_bias, y - 30), font, 0.6, color,
                            thickness=2)
                eye_bias = 0

            # Softmax score panels: left eye at x=10, right eye at x=200.
            for panel_x, side, probs in ((10, "left", score_buff[0][0]),
                                         (200, "right", score_buff[1][0])):
                for row, (cls_name, prob) in enumerate(
                        zip(("open", "close", "other"), probs)):
                    cv2.putText(frame,
                                "{}_{}:{:.2f}".format(side, cls_name, prob),
                                (panel_x, 20 * (row + 1)), font, 0.6,
                                score_color, thickness=2)

            (id0, score0), (id1, score1) = result_buff
            # Class and score of the more confident eye (ties go to the right eye).
            max_id, max_score = (id0, score0) if score0 > score1 else (id1, score1)
            # Index of the wider eye box (ties go to the right eye) and of the other.
            widest = 0 if eye_wild_buf[0] > eye_wild_buf[1] else 1
            narrow = 1 - widest
            max_wilde_id, max_wilde_score = result_buff[widest]
            second_id, second_score = result_buff[narrow]

            cv2.putText(frame,
                        "w:[{:.2f},{:.2f}]".format(eye_wild_buf[0], eye_wild_buf[1]),
                        (400, 80), font, 0.6, info_color, thickness=2)
            cv2.putText(frame,
                        "score:[left: {:.2f}] [right: {:.2f}]".format(
                            score_buff[0][0][2], score_buff[1][0][2]),
                        (400, 100), font, 0.6, info_color, thickness=2)

            min_width = min(eye_wild_buf[0], eye_wild_buf[1])
            is_other = (
                # Either eye box is wider than 23 px and confidently "other".
                (eye_wild_buf[narrow] > 23 and second_score > 0.8 and second_id == 2)
                or (eye_wild_buf[widest] > 23 and max_wilde_score > 0.8
                    and max_wilde_id == 2)
                # An eye box under 17 px while the wider eye is not confident.
                or (min_width < 17.0 and max_wilde_score < 0.8)
                # One eye says "other" (>0.5) and the other eye is below 0.9.
                or (max_wilde_id == 2 and max_wilde_score > 0.5 and second_score < 0.9)
                or (second_id == 2 and second_score > 0.5 and max_wilde_score < 0.9)
                # Boxes narrower than 10 px are judged "other" outright.
                or min_width < 10.0
            )
            if is_other:
                cv2.putText(frame, "other", (400, 60), font, 0.6, (0, 0, 255),
                            thickness=2)
            elif (max_wilde_id == 1 and max_wilde_score >= 0.80) \
                    or (max_id == 1 and max_score > 0.750):
                # The wider eye is "close" with >= 0.80, or the more confident
                # eye is "close" with > 0.75.
                cv2.putText(frame, "close", (400, 60), font, 0.6, (0, 255, 0),
                            thickness=2)
            else:
                cv2.putText(frame, "open", (400, 60), font, 0.6, (255, 0, 0),
                            thickness=2)

            out.write(frame)
    finally:
        # Finalize the output file and free the capture on every exit path.
        cap.release()
        out.release()
コード例 #10
0
def main():
    """Classify eye state on a test video using a two-eye composite input.

    Every frame goes through the MTCNN detector. For each detected face, two
    eye boxes are built from the 24-point landmarks:
    - left eye: indices 6, 8, 10, 11, 14;
    - right eye: indices 7, 9, 12, 13, 15.

    Each box is padded to a square with a margin, cropped and resized to
    24x24. The first two crops (the first face's eyes) are stacked side by
    side into one 24x48 RGB image and classified by MixNet. The predicted
    label is drawn in the top-left corner. Frames are written to
    ``save_path_root``.

    The capture and writer are released on exit.
    """
    eye_class_dict = {0: "open_eye", 1: "close_eye", 2: "other"}
    point_nums = 24
    threshold = [0.6, 0.7, 0.7]
    data_trans = Transforms.Compose([
        Transforms.ToTensor(),
        Transforms.Normalize((0.407, 0.405, 0.412), (0.087, 0.087, 0.087)),
    ])
    mixnet = MixNet(input_size=(24, 48), num_classes=3)
    weight_dict = torch.load("weight/change_mix_data_0202/Mixnet_epoch_59.pth")
    # Drop the first 7 characters of every key. Presumably this is the
    # "module." prefix nn.DataParallel adds when saving; confirm for new
    # checkpoints.
    new_state_dict = OrderedDict((k[7:], v) for k, v in weight_dict.items())
    mixnet.load_state_dict(new_state_dict)
    mixnet.to('cuda:0')
    mixnet.eval()

    pnet, rnet, onet = create_mtcnn_net(
        p_model_path=r'model_store/final/pnet_epoch_19.pt',
        r_model_path=r'model_store/final/rnet_epoch_7.pt',
        o_model_path=r'model_store/final/onet_epoch_92.pt',
        use_cuda=True)
    mtcnn_detector = MtcnnDetector(pnet=pnet,
                                   rnet=rnet,
                                   onet=onet,
                                   min_face_size=24,
                                   threshold=threshold)
    videos_root_path = 'test_video/20200522164730261_0.avi'
    save_path_root = 'result_video/20200522164730261_0.avi'

    cap = cv2.VideoCapture(videos_root_path)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    fps = cap.get(cv2.CAP_PROP_FPS)
    size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    out = cv2.VideoWriter(save_path_root, fourcc, fps, size)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("finish")
                break

            copy_frame = frame.copy()
            left_right_eye = []
            bboxs, landmarks, wearmask = mtcnn_detector.detect_face(frame,
                                                                    rgb=True)

            if landmarks is not None:
                for i in range(landmarks.shape[0]):
                    landmarks_one = landmarks[i, :].reshape((point_nums, 2))
                    left_eye = np.array(landmarks_one[[6, 8, 10, 11, 14], :])
                    right_eye = np.array(landmarks_one[[7, 9, 12, 13, 15], :])
                    # Axis-aligned bounding box [xmin, ymin, xmax, ymax] per eye.
                    for eye in (left_eye, right_eye):
                        left_right_eye.append([
                            np.min(eye[:, 0]), np.min(eye[:, 1]),
                            np.max(eye[:, 0]), np.max(eye[:, 1]),
                        ])
                    # Mark the landmarks used for both eyes.
                    for j in [*left_eye, *right_eye]:
                        cv2.circle(frame, (int(j[0]), int(j[1])), 2,
                                   (255, 0, 0), -1)

                crop_img = []
                for xmin, ymin, xmax, ymax in left_right_eye:
                    w, h = xmax - xmin, ymax - ymin
                    if w <= 0 or h <= 0:
                        # Degenerate box: the padding maths below divides by w or h.
                        crop_img.append(None)
                        continue
                    # Pad the shorter side so the box becomes square. Then add a
                    # fixed margin of k times the longer side on every edge.
                    k = 0.1
                    ratio = h / w
                    if ratio > 1:
                        ratio = ratio - 1
                        xmin -= (ratio / 2 * w + k * h)
                        ymin -= (k * h)
                        xmax += (ratio / 2 * w + k * h)
                        ymax += (k * h)
                    else:
                        ratio = w / h - 1
                        xmin -= (k * w)
                        ymin -= (ratio / 2 * h + k * w)
                        xmax += (k * w)
                        ymax += (ratio / 2 * h + k * w)
                    cv2.rectangle(frame, (int(xmin), int(ymin)),
                                  (int(xmax), int(ymax)), (0, 255, 255), 2)
                    # Clamp to the frame. A negative slice bound would wrap
                    # around, and cv2.resize fails on an empty crop.
                    x0, y0 = max(int(xmin), 0), max(int(ymin), 0)
                    x1, y1 = max(int(xmax), 0), max(int(ymax), 0)
                    temp_img = copy_frame[y0:y1, x0:x1]
                    crop_img.append(
                        cv2.resize(temp_img, (24, 24)) if temp_img.size else None)

                if len(crop_img) < 2 or crop_img[0] is None or crop_img[1] is None:
                    out.write(frame)
                    continue

                # Two 24x24 crops side by side give the model's 24x48 input.
                compose_img = np.hstack((crop_img[0], crop_img[1]))
                compose_img = cv2.cvtColor(compose_img, cv2.COLOR_BGR2RGB)
                img = data_trans(Image.fromarray(compose_img)).unsqueeze(0)
                with torch.no_grad():
                    outputs = mixnet(img.to('cuda:0'))
                    label_id = torch.max(outputs, 1)[1].item()
                cv2.putText(frame, eye_class_dict[label_id], (0, 20),
                            cv2.FONT_HERSHEY_COMPLEX, 1.3,
                            (255, 0, 255) if label_id == 0 else (255, 255, 0),
                            thickness=2)
            out.write(frame)
    finally:
        # Finalize the output file and free the capture on every exit path.
        cap.release()
        out.release()