def main():
    """Run a pretrained Keypoint R-CNN on the first COCO keypoint validation
    images and save pose visualisations to ``./result/<index>.jpg``.

    Detections whose score is not above ``detect_threshold`` are discarded.
    For the remaining detections, every keypoint whose score is below
    ``keypoint_score_threshold`` has its whole (x, y, v) triple set to 0
    before the poses are handed to ``plot_poses``.
    """
    import os  # local import: the file's import block is outside this view

    num_images = 20
    detect_threshold = 0.7
    # NOTE(review): a threshold above 1 suggests 'keypoints_scores' are
    # unnormalised logits rather than probabilities — confirm.
    keypoint_score_threshold = 2
    result_dir = './result'

    # Fall back to CPU so the script still runs on machines without a GPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print("Loading data")
    dataset_test, num_classes = get_dataset("coco_kp", "val", get_transform(train=False))

    print("Creating model")
    model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
        num_classes=num_classes, pretrained=True)
    model.to(device)
    model.eval()

    # Make sure the output directory exists before anything is written to it.
    os.makedirs(result_dir, exist_ok=True)

    with torch.no_grad():
        # Clamp so a dataset with fewer than `num_images` samples does not
        # raise IndexError.
        for i in range(min(num_images, len(dataset_test))):
            img, _ = dataset_test[i]
            prediction = model([img.to(device)])[0]

            # Keep only confident detections.
            scores = prediction['scores'].cpu().numpy()
            keep = scores > detect_threshold
            keypoints = prediction['keypoints'].cpu().numpy()[keep]
            keypoints_scores = prediction['keypoints_scores'].cpu().numpy()[keep]

            # keypoints is indexed as [detection][keypoint] -> (x, y, v) and
            # keypoints_scores as [detection][keypoint]; the boolean mask
            # therefore selects whole (x, y, v) triples, which are zeroed.
            keypoints[keypoints_scores < keypoint_score_threshold] = 0

            # CHW float tensor (presumably in [0, 1]) -> HWC uint8 for plotting.
            img = img.mul(255).permute(1, 2, 0).byte().numpy()
            plot_poses(img, keypoints,
                       save_name=os.path.join(result_dir, str(i) + '.jpg'))
def _predict_multiscale(original, save_dir):
    """Run the network on ``original`` at every scale in ``multiscale``.

    The first-scale image, with its channel order reversed ([2, 1, 0];
    presumably BGR -> RGB since the source comes from ``cv2.imread``), is
    saved as ``input_image.jpg`` in ``save_dir`` and returned.

    Relies on module-level ``multiscale``, ``batch_size``, ``height``,
    ``width``, ``sess``, ``outputs`` and ``tf_img`` defined elsewhere.

    Returns:
        (img, sample_output): the channel-reversed first-scale image and the
        list of network outputs averaged element-wise over all scales.

    Raises:
        ValueError: if ``multiscale`` is empty.
    """
    if len(multiscale) == 0:
        raise ValueError('multiscale must contain at least one scale')

    img = None
    scale_outputs = []
    for i, scale in enumerate(multiscale):
        scale_img = get_multi_scale_img(img=original, scale=scale)
        if i == 0:
            img = scale_img[:, :, [2, 1, 0]]
            plt.imsave(os.path.join(save_dir, 'input_image.jpg'), img)
        # The graph is fed a full batch; only slot 0 holds the real image.
        imgs_batch = np.zeros(
            (batch_size, int(scale * height), int(scale * width), 3))
        imgs_batch[0] = scale_img
        one_scale_output = sess.run(outputs[i], feed_dict={tf_img[i]: imgs_batch})
        # Keep only the prediction for batch slot 0 from every output tensor.
        scale_outputs.append([o[0] for o in one_scale_output])

    # Element-wise mean of each output across scales (same summation order as
    # accumulating scale by scale), without mutating the per-scale results.
    n_scales = len(multiscale)
    sample_output = [sum(maps) / n_scales for maps in zip(*scale_outputs)]
    return img, sample_output


def process_image(filepath, save_dir):
    """Run multi-scale inference on one image and write visualisations.

    Files written directly here into ``save_dir``: ``input_image.jpg``,
    ``segmentation_mask.jpg``, ``mask_new_scale.jpg`` and ``MASK.jpg``.
    ``plot_poses`` and ``plot_instance_masks`` also receive ``save_dir``;
    what they write is defined elsewhere.

    Args:
        filepath: path of the image to read with ``cv2.imread``.
        save_dir: directory that receives the output images (it is not
            created here).

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``filepath``.
        ValueError: if ``multiscale`` is empty.
    """
    original = cv2.imread(filepath)
    if original is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError('Could not read image: {}'.format(filepath))

    img, sample_output = _predict_multiscale(original, save_dir)

    # Heatmaps are built from the keypoint maps and the short-range offsets.
    # Gaussian filtering helps when there are multiple local maxima for the
    # same keypoint; each channel is filtered independently.
    H = compute_heatmaps(kp_maps=sample_output[0], short_offsets=sample_output[1])
    for k in range(H.shape[2]):
        H[:, :, k] = gaussian_filter(H[:, :, k], sigma=2)

    # Keypoints are grouped into skeletons via the mid-range offsets. Skeletons
    # with 4 or fewer keypoints whose third column is positive are dropped.
    pred_kp = get_keypoints(H)
    pred_skels = group_skeletons(keypoints=pred_kp, mid_offsets=sample_output[2])
    pred_skels = [skel for skel in pred_skels if (skel[:, 2] > 0).sum() > 4]
    print('Number of detected skeletons: {}'.format(len(pred_skels)))
    plot_poses(img, pred_skels, save_path=save_dir)

    # Binary segmentation mask, thresholded at 0.5, computed once and reused.
    seg_mask = sample_output[4][:, :, 0] > 0.5
    applied_mask = apply_mask(img, seg_mask, color=[255, 0, 0])
    plt.imsave(os.path.join(save_dir, 'segmentation_mask.jpg'), applied_mask)

    # White-on-black rendering of the mask, replicated across all channels.
    mask_img = np.empty(img.shape, dtype=np.uint8)
    mask_img[:] = np.where(seg_mask, 255, 0)[:, :, np.newaxis]
    plt.imsave(os.path.join(save_dir, 'mask_new_scale.jpg'), mask_img)
    mask_img = crop_to_original_size(original, mask_img)
    plt.imsave(os.path.join(save_dir, 'MASK.jpg'), mask_img)

    # Instance masks combine the kept skeletons with the long-range offsets and
    # the segmentation output. Guard on the skeletons actually passed in: the
    # filter above can remove every skeleton even when keypoints were found.
    if pred_skels:
        instance_masks = get_instance_masks(
            pred_skels, sample_output[-1][:, :, 0], sample_output[-2])
        plot_instance_masks(instance_masks, img, save_path=save_dir)

    # Reported last, once the files have actually been written.
    print('Visualization image has been saved into ', save_dir)