class P_net_test():
    """Two-stage MTCNN-style cascade (P-net, then R-net) used to mine boxes.

    ``test`` runs both networks over a joblib list of ``(image_path, gt_bboxes)``
    pairs and dumps the surviving R-net boxes to a joblib file.

    Config attributes read: ``scale``, ``P_net_conf_thresh``,
    ``P_net_iou_thresh``, ``R_net_conf_thresh``, ``R_net_iou_thresh`` and
    ``roi_min_size`` (index 0 for the P-net stage, index 1 for the R-net stage).
    """

    # Upper bound on the number of P-net candidates forwarded to the R-net.
    max_P_net_candidates = 50000
    # IoU threshold of the NMS applied inside each pyramid level.
    P_net_level_iou_thresh = 0.5
    # Smallest pyramid-level side length; levels are also padded to a multiple of it.
    min_input_size = 12

    def __init__(self, config):
        self.config = config
        self.scale = config.scale
        # Precomputed anchor grid, sliced per pyramid level in _detect_at_level.
        # NOTE(review): presumably shaped (rows, cols, 4) -- confirm get_anchors.
        self.anchors = cuda(get_anchors(5000))
        self.P_net = P_net()
        self.R_net = R_net()

    def get_res(self, x):
        """Run the full P-net -> R-net cascade on one image.

        Args:
            x: (H, W, C) image tensor, e.g. ``torch.tensor(cv2.imread(path))``.

        Returns:
            (N, 5) tensor of ``[x1, y1, x2, y2, score]`` rows from the R-net
            stage, or the empty P-net result when the P-net finds nothing.
        """
        x = cuda(x.float())
        x = x[None].permute(0, 3, 1, 2)  # HWC -> NCHW with a batch of one

        candidates = self.get_P_net_res(x)
        # get_P_net_res sorts by descending score before its final NMS; if
        # _box_nms preserves that order, this cap keeps the strongest boxes.
        candidates = candidates[:self.max_P_net_candidates]
        if candidates.shape[0] == 0:
            return candidates
        return self.get_R_net_res(x, candidates[..., :4])

    def get_R_net_res(self, x, tanchors):
        """Refine candidate boxes with the R-net.

        Args:
            x: (1, C, H, W) float image tensor holding raw pixel values.
            tanchors: (N, >=4) tensor whose first four columns are
                ``[x1, y1, x2, y2]`` candidates.

        Returns:
            (M, 5) tensor of ``[x1, y1, x2, y2, score]`` rows that passed the
            confidence threshold, clipping, the minimum-size filter and NMS;
            a (0, 5) tensor when nothing survives.
        """
        h, w = x.shape[2:]  # pre-padding size, used for clipping below
        tanchors = bbox2square(tanchors[..., :4])
        if tanchors.shape[0] == 0:
            return cuda(torch.zeros((0, 5)))

        # Squared boxes may extend past the right/bottom border: pad the image
        # with 127.5 (normalises to ~0 below) so every crop stays inside it.
        # Convert to Python ints explicitly; F.pad expects integer pad sizes.
        max_x2, max_y2 = (tanchors[:, 2:4].max(dim=0)[0] + 1).long().tolist()
        pad_w, pad_h = max(max_x2 - w, 0), max(max_y2 - h, 0)
        if pad_w or pad_h:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='constant', value=127.5)

        # ROI rows are [batch_index, x1, y1, x2, y2]; there is only image 0.
        batch_index = tanchors.new_zeros((tanchors.shape[0], 1))
        roi = torch.cat([batch_index, tanchors], dim=1)
        crops = (roialign_24(x, roi) - 127.5) / 128.0
        logits, loc, _ = self.R_net(crops)

        score = F.softmax(logits, dim=-1)[..., 1]
        inds = score >= self.config.R_net_conf_thresh
        if inds.sum() == 0:
            return cuda(torch.zeros((0, 5)))
        score = score[inds]
        bboxes = loc2bbox(loc[inds], tanchors[inds])

        bboxes, score = self._clip_and_filter(
            bboxes, score, w, h, self.config.roi_min_size[1])
        if bboxes.shape[0] == 0:
            return cuda(torch.zeros((0, 5)))

        bboxes, score = self._sort_and_nms(bboxes, score, self.config.R_net_iou_thresh)
        return torch.cat([bboxes, score.view(-1, 1)], dim=-1)

    def get_P_net_res(self, x):
        """Run the P-net over an image pyramid and merge the detections.

        Each level resizes the whole image by ``scale ** level`` (via ROIAlign),
        pads it to a multiple of ``min_input_size`` and keeps boxes that pass
        the confidence threshold, clipping, the minimum-size filter and a
        per-level NMS. The levels are then merged with a final NMS.

        Args:
            x: (1, C, H, W) float image tensor holding raw pixel values.

        Returns:
            (N, 5) tensor of ``[x1, y1, x2, y2, score]`` rows in original-image
            coordinates; a (0, 5) tensor when nothing survives.

        Raises:
            ValueError: if ``self.scale`` is not in (0, 1).
        """
        h, w = x.shape[2:]
        # A single ROI covering the whole image (batch index 0).
        roi = cuda(torch.tensor([[0, 0, 0, w, h]]).float())

        all_bboxes, all_score = [], []
        for n_h, n_w, factor in self._pyramid_levels(h, w, self.scale, self.min_input_size):
            level = self._detect_at_level(x, roi, n_h, n_w, factor, w, h)
            if level is not None:
                all_bboxes.append(level[0])
                all_score.append(level[1])
        if not all_bboxes:
            return cuda(torch.zeros((0, 5)))

        bboxes, score = self._sort_and_nms(
            torch.cat(all_bboxes, dim=0), torch.cat(all_score, dim=0),
            self.config.P_net_iou_thresh)
        return torch.cat([bboxes, score.view(-1, 1)], dim=1)

    def _detect_at_level(self, x, roi, n_h, n_w, factor, w, h):
        """Run the P-net on one pyramid level; return (bboxes, score) or None."""
        # ROIAlign over the full-image ROI acts as a resize to (n_h, n_w).
        xx = ROIAlign((n_h, n_w), 1 / 1., 2)(x, roi)
        pad_h = self._round_up(n_h, self.min_input_size) - n_h
        pad_w = self._round_up(n_w, self.min_input_size) - n_w
        xx = F.pad(xx, (0, pad_w, 0, pad_h), mode='constant', value=127.5)
        xx = (xx - 127.5) / 128.0

        logits, loc, _ = self.P_net(xx)
        map_H, map_W = logits.shape[2:]
        logits = logits.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        loc = loc.permute(0, 2, 3, 1).contiguous().view(-1, 4)
        # Dividing by the level factor maps the anchors to original-image scale.
        anchors = self.anchors[:map_H, :map_W].contiguous().view(-1, 4) / factor

        score = F.softmax(logits, dim=-1)[..., 1]
        inds = score >= self.config.P_net_conf_thresh
        if inds.sum() == 0:
            return None
        bboxes = loc2bbox(loc[inds], anchors[inds])
        bboxes, score = self._clip_and_filter(
            bboxes, score[inds], w, h, self.config.roi_min_size[0])
        if bboxes.shape[0] == 0:
            return None
        return self._sort_and_nms(bboxes, score, self.P_net_level_iou_thresh)

    @staticmethod
    def _pyramid_levels(h, w, scale, min_size=12):
        """Yield ``(n_h, n_w, factor)`` for each pyramid level of an h x w image.

        Level i uses ``factor = scale ** i``; iteration stops at the first level
        whose height or width falls below ``min_size``.

        Raises:
            ValueError: if ``scale`` is not in (0, 1); the pyramid would never
                shrink and the caller would loop forever.
        """
        if not 0 < scale < 1:
            raise ValueError('scale must be in (0, 1), got %r' % (scale,))
        i = 0
        while True:
            factor = scale ** i
            n_h, n_w = int(h * factor), int(w * factor)
            if n_h < min_size or n_w < min_size:
                return
            yield n_h, n_w, factor
            i += 1

    @staticmethod
    def _round_up(n, multiple):
        """Return the smallest multiple of ``multiple`` that is >= n (for n >= 0)."""
        return -(-n // multiple) * multiple

    @staticmethod
    def _clip_and_filter(bboxes, score, w, h, min_size):
        """Clip ``[x1, y1, x2, y2]`` boxes to a w x h image (in place) and drop small ones.

        A box survives only if both its width and height are >= ``min_size``
        after clipping. Returns the filtered ``(bboxes, score)`` pair.
        """
        bboxes[..., 0:4:2] = torch.clamp(bboxes[..., 0:4:2], 0, w - 1)
        bboxes[..., 1:4:2] = torch.clamp(bboxes[..., 1:4:2], 0, h - 1)
        wh = bboxes[..., 2:4] - bboxes[..., :2]
        keep = (wh >= min_size).all(dim=-1)
        return bboxes[keep], score[keep]

    @staticmethod
    def _sort_and_nms(bboxes, score, iou_thresh):
        """Sort by descending score, then keep the rows selected by ``_box_nms``."""
        score, order = score.sort(descending=True)
        bboxes = bboxes[order]
        keep = _box_nms(bboxes, score, iou_thresh)
        return bboxes[keep], score[keep]

    def test(self, model_file,
             data_file='/home/zhai/PycharmProjects/Demo35/MTCNN/data_process/wider.pkl',
             out_file='O_net_train_wider_1.pkl'):
        """Run the cascade over a dataset and dump the detections with joblib.

        Args:
            model_file: pair ``(P_net_weights, R_net_weights)`` of paths
                loadable with ``torch.load``.
            data_file: joblib file holding a list of ``(image_path, gt_bboxes)``
                pairs; entries whose ``gt_bboxes`` is empty are skipped.
            out_file: destination of the joblib dump, a list of
                ``[image_path, gt_bboxes, detections]`` where ``detections`` is a
                float32 (N, 4) array of ``[x1, y1, x2, y2]``. Images without
                detections are omitted.
        """
        for net, weights in ((self.P_net, model_file[0]), (self.R_net, model_file[1])):
            net.load_state_dict(torch.load(weights, map_location='cpu'))
            net.eval()
            cuda(net)

        test_dir = joblib.load(data_file)
        print(len(test_dir))
        O_net_train = []
        i = 0
        for im_file, bboxes in test_dir:
            if bboxes.shape[0] == 0:
                continue
            i += 1

            img = cv2.imread(im_file)
            if img is None:
                # cv2.imread reports missing/unreadable files by returning None.
                print('warning: cannot read %s, skipping' % im_file)
                continue
            print(datetime.now(), i, img.shape)

            with torch.no_grad():
                res = self.get_res(torch.tensor(img))
            if res.shape[0] == 0:
                continue

            res = res.detach().cpu().numpy()[..., :4].astype(np.float32)
            print(res.shape)
            O_net_train.append([im_file, bboxes, res])
        joblib.dump(O_net_train, out_file)
# Example #2
# 0
class P_net_test():
    """P-net-only detector evaluated on FDDB.

    ``test`` runs the P-net over every FDDB fold and writes one detection file
    per fold.

    Config attributes read: ``scale``, ``P_net_conf_thresh``,
    ``P_net_iou_thresh`` and ``roi_min_size`` (index 0).
    """

    # IoU threshold of the NMS applied inside each pyramid level.
    P_net_level_iou_thresh = 0.5
    # Smallest pyramid-level side length; levels are also padded to a multiple of it.
    min_input_size = 12

    def __init__(self, config):
        self.config = config
        self.scale = config.scale
        # Precomputed anchor grid, sliced per pyramid level in _detect_at_level.
        # NOTE(review): presumably shaped (rows, cols, 4) -- confirm get_anchors.
        self.anchors = cuda(get_anchors(5000))
        self.P_net = P_net()

    def get_res(self, x):
        """Detect boxes in one (H, W, C) image tensor with the P-net.

        Returns an (N, 5) tensor of ``[x1, y1, x2, y2, score]`` rows.
        """
        x = cuda(x.float())
        x = x[None].permute(0, 3, 1, 2)  # HWC -> NCHW with a batch of one
        return self.get_P_net_res(x)

    def get_P_net_res(self, x):
        """Run the P-net over an image pyramid and merge the detections.

        Each level resizes the whole image by ``scale ** level`` (via ROIAlign),
        pads it to a multiple of ``min_input_size`` and keeps boxes that pass
        the confidence threshold, clipping, the minimum-size filter and a
        per-level NMS. The levels are then merged with a final NMS.

        Args:
            x: (1, C, H, W) float image tensor holding raw pixel values.

        Returns:
            (N, 5) tensor of ``[x1, y1, x2, y2, score]`` rows in original-image
            coordinates; a (0, 5) tensor when nothing survives.

        Raises:
            ValueError: if ``self.scale`` is not in (0, 1).
        """
        h, w = x.shape[2:]
        # A single ROI covering the whole image (batch index 0).
        roi = cuda(torch.tensor([[0, 0, 0, w, h]]).float())

        all_bboxes, all_score = [], []
        for n_h, n_w, factor in self._pyramid_levels(h, w, self.scale, self.min_input_size):
            level = self._detect_at_level(x, roi, n_h, n_w, factor, w, h)
            if level is not None:
                all_bboxes.append(level[0])
                all_score.append(level[1])
        if not all_bboxes:
            return cuda(torch.zeros((0, 5)))

        bboxes, score = self._sort_and_nms(
            torch.cat(all_bboxes, dim=0), torch.cat(all_score, dim=0),
            self.config.P_net_iou_thresh)
        return torch.cat([bboxes, score.view(-1, 1)], dim=1)

    def _detect_at_level(self, x, roi, n_h, n_w, factor, w, h):
        """Run the P-net on one pyramid level; return (bboxes, score) or None."""
        # ROIAlign over the full-image ROI acts as a resize to (n_h, n_w).
        xx = ROIAlign((n_h, n_w), 1 / 1., 2)(x, roi)
        pad_h = self._round_up(n_h, self.min_input_size) - n_h
        pad_w = self._round_up(n_w, self.min_input_size) - n_w
        # Pad with 127.5, which the normalisation below maps to ~0.
        xx = F.pad(xx, (0, pad_w, 0, pad_h), mode='constant', value=127.5)
        xx = (xx - 127.5) / 128.0

        logits, loc, _ = self.P_net(xx)
        map_H, map_W = logits.shape[2:]
        logits = logits.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        loc = loc.permute(0, 2, 3, 1).contiguous().view(-1, 4)
        # Dividing by the level factor maps the anchors to original-image scale.
        anchors = self.anchors[:map_H, :map_W].contiguous().view(-1, 4) / factor

        score = F.softmax(logits, dim=-1)[..., 1]
        inds = score >= self.config.P_net_conf_thresh
        if inds.sum() == 0:
            return None
        bboxes = loc2bbox(loc[inds], anchors[inds])
        bboxes, score = self._clip_and_filter(
            bboxes, score[inds], w, h, self.config.roi_min_size[0])
        if bboxes.shape[0] == 0:
            return None
        return self._sort_and_nms(bboxes, score, self.P_net_level_iou_thresh)

    @staticmethod
    def _pyramid_levels(h, w, scale, min_size=12):
        """Yield ``(n_h, n_w, factor)`` for each pyramid level of an h x w image.

        Level i uses ``factor = scale ** i``; iteration stops at the first level
        whose height or width falls below ``min_size``.

        Raises:
            ValueError: if ``scale`` is not in (0, 1); the pyramid would never
                shrink and the caller would loop forever.
        """
        if not 0 < scale < 1:
            raise ValueError('scale must be in (0, 1), got %r' % (scale,))
        i = 0
        while True:
            factor = scale ** i
            n_h, n_w = int(h * factor), int(w * factor)
            if n_h < min_size or n_w < min_size:
                return
            yield n_h, n_w, factor
            i += 1

    @staticmethod
    def _round_up(n, multiple):
        """Return the smallest multiple of ``multiple`` that is >= n (for n >= 0)."""
        return -(-n // multiple) * multiple

    @staticmethod
    def _clip_and_filter(bboxes, score, w, h, min_size):
        """Clip ``[x1, y1, x2, y2]`` boxes to a w x h image (in place) and drop small ones.

        A box survives only if both its width and height are >= ``min_size``
        after clipping. Returns the filtered ``(bboxes, score)`` pair.
        """
        bboxes[..., 0:4:2] = torch.clamp(bboxes[..., 0:4:2], 0, w - 1)
        bboxes[..., 1:4:2] = torch.clamp(bboxes[..., 1:4:2], 0, h - 1)
        wh = bboxes[..., 2:4] - bboxes[..., :2]
        keep = (wh >= min_size).all(dim=-1)
        return bboxes[keep], score[keep]

    @staticmethod
    def _sort_and_nms(bboxes, score, iou_thresh):
        """Sort by descending score, then keep the rows selected by ``_box_nms``."""
        score, order = score.sort(descending=True)
        bboxes = bboxes[order]
        keep = _box_nms(bboxes, score, iou_thresh)
        return bboxes[keep], score[keep]

    @staticmethod
    def _fold_output_path(out_dir, fold):
        """Return ``<out_dir>fold-NN-out.txt`` for a 1-based fold number (zero-padded)."""
        return '%sfold-%02d-out.txt' % (out_dir, fold)

    @staticmethod
    def _format_fddb_entry(file, res):
        """Format one image's detections as FDDB result-file text.

        ``res`` is an (N, 5) array of ``[x1, y1, x2, y2, score]`` rows; it is not
        modified. The text is the image name, the detection count, then one
        ``left top width height score`` line per detection, newline-terminated.
        """
        boxes = res.copy()
        boxes[..., 2:4] = boxes[..., 2:4] - boxes[..., :2]  # corners -> width/height
        lines = [file, str(boxes.shape[0])]
        lines.extend(' '.join(str(v) for v in bbox[:5]) for bbox in boxes)
        return '\n'.join(lines) + '\n'

    def test(self, model_file,
             data_file='/home/zhai/PycharmProjects/Demo35/MTCNN/data_process/FDDB_test.pkl',
             image_dir='/home/zhai/PycharmProjects/Demo35/dataset/FDDB/',
             out_dir='/home/zhai/PycharmProjects/Demo35/dataset/FDDB/res/'):
        """Run the P-net on every FDDB fold and write per-fold detection files.

        Args:
            model_file: path to P-net weights loadable with ``torch.load``.
            data_file: joblib file holding one list of image names per fold.
            image_dir: image root; must end with a path separator because
                ``image_dir + name + '.jpg'`` is used as the image path.
            out_dir: directory (ending with a separator) receiving
                ``fold-NN-out.txt`` files.

        Raises:
            FileNotFoundError: if an image cannot be read; a fold file missing
                an image would be rejected by the FDDB evaluation anyway.
        """
        self.P_net.load_state_dict(torch.load(model_file, map_location='cpu'))
        self.P_net.eval()
        cuda(self.P_net)

        test_dir = joblib.load(data_file)
        z = 0
        for fold, files in enumerate(test_dir, start=1):
            w_file = self._fold_output_path(out_dir, fold)
            with open(w_file, 'w', encoding='utf-8') as f:
                for file in files:
                    z += 1
                    im_file = image_dir + file + '.jpg'
                    img = cv2.imread(im_file)
                    if img is None:
                        raise FileNotFoundError('cannot read image %s' % im_file)

                    with torch.no_grad():
                        res = self.get_res(torch.tensor(img))
                    res = res.detach().cpu().numpy()

                    print(datetime.now(), z, res.shape[0])
                    f.write(self._format_fddb_entry(file, res))