Code Example #1
File: keypoints.py    Project: zxt881108/pytorch-cv
def plot_keypoints(img, coords, confidence, class_ids, bboxes, scores,
                   box_thresh=0.5, keypoint_thresh=0.2, **kwargs):
    """Visualize keypoints.

    Parameters
    ----------
    img : numpy.ndarray or torch.Tensor
        Image with shape `H, W, 3`.
    coords : numpy.ndarray or torch.Tensor
        Array with shape `Batch, N_Joints, 2`.
    confidence : numpy.ndarray or torch.Tensor
        Array with shape `Batch, N_Joints, 1`.
    class_ids : numpy.ndarray or torch.Tensor
        Class IDs.
    bboxes : numpy.ndarray or torch.Tensor
        Bounding boxes with shape `N, 4`. Where `N` is the number of boxes.
    scores : numpy.ndarray or torch.Tensor, optional
        Confidence scores of the provided `bboxes` with shape `N`.
    box_thresh : float, optional, default 0.5
        Display threshold if `scores` is provided. Boxes with scores below `box_thresh`
        will be ignored in display.
    keypoint_thresh : float, optional, default 0.2
        Keypoints with confidence less than `keypoint_thresh` will be ignored in display.

    Returns
    -------
    matplotlib axes
        The plotted axes.

    """
    if isinstance(coords, torch.Tensor):
        coords = coords.cpu().numpy()
    if isinstance(class_ids, torch.Tensor):
        class_ids = class_ids.cpu().numpy()
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.cpu().numpy()
    if isinstance(scores, torch.Tensor):
        scores = scores.cpu().numpy()
    if isinstance(confidence, torch.Tensor):
        confidence = confidence.cpu().numpy()

    joint_visible = confidence[:, :, 0] > keypoint_thresh
    joint_pairs = [[0, 1], [1, 3], [0, 2], [2, 4],
                   [5, 6], [5, 7], [7, 9], [6, 8], [8, 10],
                   [5, 11], [6, 12], [11, 12],
                   [11, 13], [12, 14], [13, 15], [14, 16]]

    person_ind = class_ids[0] == 0
    ax = plot_bbox(img, bboxes[0][person_ind[:, 0]],
                   scores[0][person_ind[:, 0]], thresh=box_thresh, **kwargs)

    colormap_index = np.linspace(0, 1, len(joint_pairs))
    for i in range(coords.shape[0]):
        pts = coords[i]
        for cm_ind, jp in zip(colormap_index, joint_pairs):
            if joint_visible[i, jp[0]] and joint_visible[i, jp[1]]:
                ax.plot(pts[jp, 0], pts[jp, 1],
                        linewidth=3.0, alpha=0.7, color=plt.cm.cool(cm_ind))
                ax.scatter(pts[jp, 0], pts[jp, 1], s=20)
    return ax
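
A minimal usage sketch with synthetic inputs is shown below; the import path, image size, and random arrays are assumptions chosen only to match the shapes documented in the docstring above (coords as `Batch, N_Joints, 2`, confidence as `Batch, N_Joints, 1`).

import numpy as np
import matplotlib.pyplot as plt
# Hypothetical import path; adjust to wherever plot_keypoints lives in the project.
from keypoints import plot_keypoints

num_persons, num_joints = 2, 17                       # COCO-style 17-joint skeleton
img = np.zeros((512, 512, 3), dtype=np.uint8)         # dummy H x W x 3 image
coords = np.random.uniform(0, 512, (num_persons, num_joints, 2))
confidence = np.random.uniform(0, 1, (num_persons, num_joints, 1))
class_ids = np.zeros((1, num_persons, 1))             # class 0 is treated as "person"
bboxes = np.random.uniform(0, 512, (1, num_persons, 4))
scores = np.random.uniform(0.5, 1.0, (1, num_persons, 1))

ax = plot_keypoints(img, coords, confidence, class_ids, bboxes, scores,
                    box_thresh=0.5, keypoint_thresh=0.2)
plt.show()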
Code Example #2
        default=0.5,
        help='Threshold of object score when visualizing the bboxes.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    # run on the CPU by default; switch to the GPU when --cuda is given
    device = torch.device('cpu')
    if args.cuda:
        device = torch.device('cuda')
    image = args.images
    # load the pretrained detector, configure NMS, and switch to evaluation mode
    net = get_model(args.network, pretrained=True, root=args.root)
    net.to(device)
    net.set_nms(0.45, 200)
    net.eval()

    ax = None
    # preprocess the image (short side resized to 512) and run a single forward pass
    x, img = load_test(image, short=512)
    x = x.to(device)
    with torch.no_grad():
        ids, scores, bboxes = [xx[0].cpu().numpy() for xx in net(x)]
    # draw the detections that score above the threshold
    ax = plot_bbox(img,
                   bboxes,
                   scores,
                   ids,
                   thresh=args.thresh,
                   class_names=net.classes,
                   ax=ax)
    plt.show()
Code Example #3
                        help='Default pre-trained model root.')
    parser.add_argument('--pretrained', type=str, default='True',
                        help='Load weights from previously saved parameters.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    device = torch.device('cpu')
    if args.cuda:
        device = torch.device('cuda')
    image = args.images
    net = get_model(args.network, pretrained=True)
    net.to(device)
    net.set_nms(0.3, 200)
    net.eval()

    ax = None
    # preprocess the image to the network's expected input size and run a single forward pass
    x, img = load_test(image, short=net.short, max_size=net.max_size)
    x = x.to(device)
    with torch.no_grad():
        ids, scores, bboxes, masks = [xx.cpu().numpy() for xx in net(x)]
    # resize the predicted masks to the full image resolution and blend them into the image
    masks = expand_mask(masks, bboxes, (img.shape[1], img.shape[0]), scores)
    img = plot_mask(img, masks)
    fig = plt.figure(figsize=(15, 15))
    ax = fig.add_subplot(1, 1, 1)
    ax = plot_bbox(img, bboxes, scores, ids,
                   class_names=net.classes, ax=ax)
    plt.show()