Example #1
# NOTE: the mankey import paths below are assumed from the package layout;
# adjust them to match your checkout.
from mankey.dataproc.spartan_supervised_db import (
    SpartanSupervisedKeypointDBConfig,
    SpartanSupervisedKeypointDatabase,
)
from mankey.dataproc.supervised_keypoint_loader import (
    SupervisedKeypointDataset,
    SupervisedKeypointDatasetConfig,
)


def main():
    # Construct the db
    db_config = SpartanSupervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = 'shoe_6_keypoint_image.yaml'
    db_config.pdc_data_root = '/home/wei/data/pdc'
    db_config.config_file_path = '/home/wei/Coding/mankey/config/shoe_logs.txt'
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_height = 256
    config.network_in_patch_width = 256
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = False
    dataset = SupervisedKeypointDataset(config)

    # Depth statistics accumulators (avg_keypoint_depth holds the running sum
    # and is divided by counter at the end)
    min_keypoint_depth: float = 10000
    max_keypoint_depth: float = 0
    avg_keypoint_depth: float = 0
    counter = 0

    # Iterate
    for i in range(len(dataset)):
        processed_entry = dataset.get_processed_entry(dataset.entry_list[i])
        n_keypoint = processed_entry.keypoint_xy_depth.shape[1]
        for j in range(n_keypoint):
            min_keypoint_depth = min(min_keypoint_depth,
                                     processed_entry.keypoint_xy_depth[2, j])
            max_keypoint_depth = max(max_keypoint_depth,
                                     processed_entry.keypoint_xy_depth[2, j])
            avg_keypoint_depth += processed_entry.keypoint_xy_depth[2, j]
            counter += 1

    # Report the depth statistics over all keypoints
    print('The min keypoint depth is ', min_keypoint_depth)
    print('The max keypoint depth is ', max_keypoint_depth)
    print('The average keypoint depth is ', avg_keypoint_depth / float(counter))
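A minimal way to run the statistics script above as a standalone program (the data paths hard-coded in main() are specific to the original author's machine and need to be adjusted):

if __name__ == '__main__':
    main()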
Example #2
import os

import cv2
import numpy as np
import torch

# NOTE: the mankey import paths below are assumed from the package layout;
# adjust them to match your checkout.
from mankey.config import parameter
from mankey.network import predict
from mankey.dataproc.supervised_keypoint_loader import (
    SupervisedKeypointDataset,
    SupervisedKeypointDatasetConfig,
)


def visualize_entry_nostage(entry_idx: int, network: torch.nn.Module,
                            dataset: SupervisedKeypointDataset,
                            config: SupervisedKeypointDatasetConfig,
                            save_dir: str):
    # The raw input
    processed_entry = dataset.get_processed_entry(
        dataset.entry_list[entry_idx])

    # The processed input
    stacked_rgbd = dataset[entry_idx][parameter.rgbd_image_key]
    normalized_xy_depth = dataset[entry_idx][parameter.keypoint_xyd_key]

    stacked_rgbd = torch.from_numpy(stacked_rgbd)
    stacked_rgbd = torch.unsqueeze(stacked_rgbd, dim=0)
    stacked_rgbd = stacked_rgbd.cuda()

    # Do forward; no gradients are needed for visualization
    with torch.no_grad():
        raw_pred = network(stacked_rgbd)
        prob_pred = raw_pred[:, 0:dataset.num_keypoints, :, :]
        depthmap_pred = raw_pred[:, dataset.num_keypoints:, :, :]
        heatmap = predict.heatmap_from_predict(prob_pred, dataset.num_keypoints)
        coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
            heatmap, dataset.num_keypoints)
        depth_pred = predict.depth_integration(heatmap, depthmap_pred)

    # To actual image coord
    coord_x = coord_x.cpu().detach().numpy()
    coord_y = coord_y.cpu().detach().numpy()
    coord_x = (coord_x + 0.5) * config.network_in_patch_width
    coord_y = (coord_y + 0.5) * config.network_in_patch_height

    # To actual depth value
    depth_pred = depth_pred.cpu().detach().numpy()
    depth_pred = (depth_pred *
                  config.depth_image_scale) + config.depth_image_mean

    # Combine them; np.int was removed from recent numpy, so use the builtin int
    keypointxy_depth_pred = np.zeros((3, dataset.num_keypoints), dtype=int)
    keypointxy_depth_pred[0, :] = coord_x[0, :, 0].astype(int)
    keypointxy_depth_pred[1, :] = coord_y[0, :, 0].astype(int)
    keypointxy_depth_pred[2, :] = depth_pred[0, :, 0].astype(int)

    # Draw the predicted keypoints on the cropped rgb patch and save it
    from mankey.utils.imgproc import draw_image_keypoint
    keypoint_rgb_cv = draw_image_keypoint(processed_entry.cropped_rgb,
                                          keypointxy_depth_pred,
                                          processed_entry.keypoint_validity)
    rgb_save_path = os.path.join(save_dir, 'image_%d_rgb.png' % entry_idx)
    cv2.imwrite(rgb_save_path, keypoint_rgb_cv)

    # The depth error
    depth_error_mm = np.abs(processed_entry.keypoint_xy_depth[2, :] -
                            keypointxy_depth_pred[2, :])
    max_depth_error = np.max(depth_error_mm)
    print('Entry %d' % entry_idx)
    print('The max depth error (mm) is ', max_depth_error)

    # The pixel error (per-keypoint L1 distance in the 256x256 patch)
    pixel_error = np.sum(np.abs(processed_entry.keypoint_xy_depth[0:2, :] -
                                keypointxy_depth_pred[0:2, :]),
                         axis=0)
    max_pixel_error = np.max(pixel_error)
    print('The max pixel error (pixel in 256x256 image) is ', max_pixel_error)
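A sketch of how the visualizer above might be driven end to end, assuming a dataset and config built as in Example #1 and a trained checkpoint on disk. construct_network() and checkpoint_path are hypothetical placeholders for however the model is built and restored:

def run_visualization(dataset, config, checkpoint_path: str, save_dir: str):
    # construct_network() is a placeholder: build the same architecture that
    # produced the checkpoint (the no-stage network used for training).
    network = construct_network()
    network.load_state_dict(torch.load(checkpoint_path))
    network = network.cuda()
    network.eval()

    # Visualize every entry in the dataset
    os.makedirs(save_dir, exist_ok=True)
    for entry_idx in range(len(dataset)):
        visualize_entry_nostage(entry_idx, network, dataset, config, save_dir)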