コード例 #1
0
def main(*, num_images=20, detect_threshold=0.7, keypoint_score_threshold=2,
         output_dir='./result'):
    """Run a pretrained Keypoint R-CNN on the first COCO-keypoint val images and save pose plots.

    For each processed image, detections whose score is not above
    *detect_threshold* are dropped, and keypoints whose score is below
    *keypoint_score_threshold* are replaced by ``[0, 0, 0]`` before the
    poses are drawn with ``plot_poses``.

    Args:
        num_images: number of leading validation images to process; clamped
            to the dataset length so small datasets do not raise IndexError.
        detect_threshold: detections must score strictly above this to be kept.
        keypoint_score_threshold: keypoints scoring below this are zeroed.
            NOTE(review): the default of 2 suggests ``keypoints_scores`` are
            not probabilities in [0, 1] — confirm the scale for the model used.
        output_dir: directory receiving ``<index>.jpg`` files; created if missing.
    """
    import os

    # Fall back to CPU so the demo still runs on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    print("Loading data")
    dataset_test, num_classes = get_dataset("coco_kp", "val",
                                            get_transform(train=False))

    print("Creating model")
    model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
        num_classes=num_classes, pretrained=True)
    model.to(device)
    model.eval()

    os.makedirs(output_dir, exist_ok=True)
    with torch.no_grad():
        for i in range(min(num_images, len(dataset_test))):
            img, _ = dataset_test[i]
            prediction = model([img.to(device)])[0]
            keypoints = prediction['keypoints'].cpu().numpy()
            scores = prediction['scores'].cpu().numpy()
            keypoints_scores = prediction['keypoints_scores'].cpu().numpy()

            keep = scores > detect_threshold
            keypoints = keypoints[keep]
            keypoints_scores = keypoints_scores[keep]
            # keypoints_scores indexes the first two axes of keypoints, so this
            # zeroes whole (x, y, v) rows for every low-confidence keypoint of
            # every kept person, whatever the number of keypoints per person.
            keypoints[keypoints_scores < keypoint_score_threshold] = 0

            img = img.mul(255).permute(1, 2, 0).byte().numpy()
            plot_poses(img, keypoints,
                       save_name=os.path.join(output_dir, '{}.jpg'.format(i)))
コード例 #2
0
def _predict_multiscale(original, save_dir):
    """Run the network once per scale in ``multiscale`` and average the outputs.

    Also saves the first-scale input (channel-reversed for matplotlib) as
    ``input_image.jpg`` in *save_dir*.

    Relies on the module-level ``multiscale``, ``batch_size``, ``height``,
    ``width``, ``sess``, ``outputs`` and ``tf_img``.

    Returns:
        ``(img, sample_output)``: ``img`` is the channel-reversed first-scale
        image; ``sample_output`` is a list with, for each network output, the
        mean over all scales of that output's first batch element.

    Raises:
        ValueError: if ``multiscale`` is empty.
    """
    if not multiscale:
        raise ValueError('multiscale must contain at least one scale')

    scale_outputs = []
    for i, scale in enumerate(multiscale):
        scale_img = get_multi_scale_img(img=original, scale=scale)
        if i == 0:
            # cv2.imread yields BGR; reverse channels for matplotlib output.
            img = scale_img[:, :, [2, 1, 0]]
            plt.imsave(os.path.join(save_dir, 'input_image.jpg'), img)
        # Only slot 0 carries the image; the rest of the batch stays zero
        # (presumably because the graph expects a fixed batch size).
        imgs_batch = np.zeros(
            (batch_size, int(scale * height), int(scale * width), 3))
        imgs_batch[0] = scale_img

        one_scale_output = sess.run(outputs[i],
                                    feed_dict={tf_img[i]: imgs_batch})
        scale_outputs.append([o[0] for o in one_scale_output])

    # Average each output across scales without mutating the per-scale arrays.
    n_scales = len(multiscale)
    sample_output = [sum(per_scale) / n_scales
                     for per_scale in zip(*scale_outputs)]
    return img, sample_output


def process_image(filepath, save_dir):
    """Run multi-scale pose estimation and segmentation on one image and save the results.

    Files written into *save_dir*: ``input_image.jpg``, the pose plot (via
    ``plot_poses``), ``segmentation_mask.jpg``, ``mask_new_scale.jpg``,
    ``MASK.jpg`` and, when any keypoints were found, the instance-mask plot
    (via ``plot_instance_masks``).

    Args:
        filepath: path of the image to load with ``cv2.imread``.
        save_dir: directory the outputs are written to.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read *filepath*.
        ValueError: if the module-level ``multiscale`` list is empty.
    """
    original = cv2.imread(filepath)
    if original is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError('Could not read image: {}'.format(filepath))

    img, sample_output = _predict_multiscale(original, save_dir)

    # Heatmaps come from the keypoint maps plus the short-range offsets.
    # Gaussian smoothing of each channel helps when one keypoint has several
    # local maxima.
    H = compute_heatmaps(kp_maps=sample_output[0],
                         short_offsets=sample_output[1])
    for k in range(H.shape[2]):
        H[:, :, k] = gaussian_filter(H[:, :, k], sigma=2)

    # Group keypoints into skeletons via the mid-range offsets. Keep only
    # skeletons with more than 4 keypoints whose third column is positive.
    pred_kp = get_keypoints(H)
    pred_skels = group_skeletons(keypoints=pred_kp,
                                 mid_offsets=sample_output[2])
    pred_skels = [skel for skel in pred_skels if (skel[:, 2] > 0).sum() > 4]
    print('Number of detected skeletons: {}'.format(len(pred_skels)))

    plot_poses(img, pred_skels, save_path=save_dir)

    # Binary segmentation: channel 0 of output 4, thresholded at 0.5.
    seg_mask = sample_output[4][:, :, 0] > 0.5
    applied_mask = apply_mask(img, seg_mask, color=[255, 0, 0])
    plt.imsave(os.path.join(save_dir, 'segmentation_mask.jpg'), applied_mask)

    # The same mask as a white-on-black RGB image at the network scale, then
    # cropped back to the original image size.
    mask_img = np.zeros(img.shape, dtype=np.uint8)
    mask_img[seg_mask] = 255
    plt.imsave(os.path.join(save_dir, 'mask_new_scale.jpg'), mask_img)
    mask_img = crop_to_original_size(original, mask_img)
    plt.imsave(os.path.join(save_dir, 'MASK.jpg'), mask_img)

    # Instance masks combine the skeletons with the segmentation output (last
    # output) and the long-range offsets (second-to-last output).
    if len(pred_kp) > 0:
        instance_masks = get_instance_masks(pred_skels,
                                            sample_output[-1][:, :, 0],
                                            sample_output[-2])
        plot_instance_masks(instance_masks, img, save_path=save_dir)

    print('Visualization image has been saved into ', save_dir)