Example #1
0
    def __init__(self, dataset):
        """Caches dataset settings and builds the helper objects.

        Label clusters are requested from LabelClusterUtils as the very
        last step, after every config-derived attribute has been set.

        Args:
            dataset: dataset object exposing ``bev_source`` and
                ``config.kitti_utils_config``.
        """
        self.dataset = dataset

        cluster_utils = LabelClusterUtils(dataset)
        self.label_cluster_utils = cluster_utils
        # Placeholders; filled in by get_clusters() at the end
        self.clusters = None
        self.std_devs = None

        self.bev_source = dataset.bev_source

        cfg = dataset.config.kitti_utils_config
        self.config = cfg
        self.area_extents = np.reshape(cfg.area_extents, (3, 2))
        # Rows 0 and 2 of the extents (presumably the x and z ranges)
        self.bev_extents = self.area_extents[[0, 2]]
        self.voxel_size = cfg.voxel_size
        self.anchor_strides = np.reshape(cfg.anchor_strides, (-1, 2))
        self._density_threshold = cfg.density_threshold

        batch_utils = MiniBatchUtils(dataset)
        self.mini_batch_utils = batch_utils
        self._mini_batch_dir = batch_utils.mini_batch_dir

        self.clusters, self.std_devs = cluster_utils.get_clusters()
    def test_get_clusters(self):
        """get_clusters returns the configured number of clusters per class,
        saves them to text files, and re-reading gives the same values."""
        # Expected clusters per class (presumably Car, Pedestrian, Cyclist)
        expected_counts = [2, 1, 1]

        cluster_utils = LabelClusterUtils(self.dataset)
        clusters, std_devs = cluster_utils.get_clusters()

        # One entry per cluster, for both the centres and the std devs
        self.assertEqual([len(c) for c in clusters], expected_counts)
        self.assertEqual([len(s) for s in std_devs], expected_counts)

        # The first call should have written the cluster text files
        cluster_dir = avod.root_dir() + "/data/label_clusters/unittest-kitti"
        self.assertTrue(os.path.isdir(cluster_dir))

        # A second call is expected to read the saved files back
        read_clusters, read_std_devs = cluster_utils.get_clusters()

        for generated, loaded in ((clusters, read_clusters),
                                  (std_devs, read_std_devs)):
            np.testing.assert_allclose(np.vstack(generated),
                                       np.vstack(loaded))
    def test_flatten_data(self):
        """_flatten_data stacks per-sample row lists into one 2D array."""
        data_to_reshape = [
            [[1, 2, 3], [4, 5, 6]],
            [[7, 8, 9]],
            [[10, 11, 12], [13, 14, 15]],
        ]

        # Values 1..15 laid out three per row, in input order
        expected_output = np.arange(1, 16).reshape(5, 3)

        cluster_utils = LabelClusterUtils(self.dataset)
        flattened = cluster_utils._flatten_data(data_to_reshape)

        np.testing.assert_array_equal(flattened, expected_output,
                                      err_msg='Wrong flattened array')
Example #4
0
    def __init__(self, dataset):
        """Caches dataset settings, builds the BEV generator and helpers.

        Args:
            dataset: dataset object exposing ``bev_source``, ``depth_dir``
                and ``config.kitti_utils_config``.

        Raises:
            FileNotFoundError: when the BEV source is 'depth' and the
                dataset's depth map directory does not exist.
        """
        self.dataset = dataset

        cluster_utils = LabelClusterUtils(dataset)
        self.label_cluster_utils = cluster_utils
        # Placeholders; filled in by get_clusters() at the end
        self.clusters = None
        self.std_devs = None

        self.bev_source = dataset.bev_source

        cfg = dataset.config.kitti_utils_config
        self.config = cfg
        self.area_extents = np.reshape(cfg.area_extents, (3, 2))
        # Rows 0 and 2 of the extents (presumably the x and z ranges)
        self.bev_extents = self.area_extents[[0, 2]]
        self.voxel_size = cfg.voxel_size
        self.anchor_strides = np.reshape(cfg.anchor_strides, (-1, 2))

        # The builder is handed self; it may read the attributes set above,
        # so this stays after them (and before _density_threshold).
        self.bev_generator = bev_generator_builder.build(cfg.bev_generator,
                                                         self)

        self._density_threshold = cfg.density_threshold

        # Depth BEVs need depth maps generated beforehand by wavedata
        wants_depth = self.bev_source == 'depth'
        if wants_depth and not os.path.exists(dataset.depth_dir):
            raise FileNotFoundError(
                'Could not find depth maps, please run '
                'demos/save_lidar_depth_maps.py in wavedata first')

        batch_utils = MiniBatchUtils(dataset)
        self.mini_batch_utils = batch_utils
        self._mini_batch_dir = batch_utils.mini_batch_dir

        self.clusters, self.std_devs = cluster_utils.get_clusters()
Example #5
0
class KittiUtils(object):
    """KITTI dataset helpers: cached config values, ground planes,
    point-cloud slice filters, anchor info and label filtering."""

    # Definition for difficulty levels
    # These values are from Kitti dataset
    # 0 - easy, 1 - medium, 2 - hard
    # Each tuple is indexed by difficulty (see _check_difficulty):
    # minimum 2D box height (y2 - y1), maximum occlusion, maximum truncation.
    HEIGHT = (40, 25, 25)
    OCCLUSION = (0, 1, 2)
    TRUNCATION = (0.15, 0.3, 0.5)

    def __init__(self, dataset):
        """Reads settings from the dataset config and prepares helpers.

        Args:
            dataset: dataset object exposing ``bev_source`` and
                ``config.kitti_utils_config``.
        """

        self.dataset = dataset

        # Label Clusters
        self.label_cluster_utils = LabelClusterUtils(self.dataset)

        # Placeholders until get_clusters() runs at the end of __init__
        self.clusters, self.std_devs = [None, None]

        # BEV source from dataset config
        self.bev_source = self.dataset.bev_source

        # Parse config
        self.config = dataset.config.kitti_utils_config
        self.area_extents = np.reshape(self.config.area_extents, (3, 2))
        # Rows 0 and 2 of area_extents (presumably the x and z ranges)
        self.bev_extents = self.area_extents[[0, 2]]
        self.voxel_size = self.config.voxel_size
        self.anchor_strides = np.reshape(self.config.anchor_strides, (-1, 2))

        self._density_threshold = self.config.density_threshold

        # Mini Batch Utils
        self.mini_batch_utils = MiniBatchUtils(self.dataset)
        self._mini_batch_dir = self.mini_batch_utils.mini_batch_dir

        # Label Clusters
        self.clusters, self.std_devs = \
            self.label_cluster_utils.get_clusters()

    def class_str_to_index(self, class_str):
        """Converts an object class type string into an integer index.

        Args:
            class_str: the object type (e.g. 'Car', 'Pedestrian', or 'Cyclist')

        Returns:
            The corresponding integer index for a class type, starting at 1
            (0 is reserved for the background class).

        Raises:
            ValueError: if class_str is not in self.dataset.classes.
        """
        classes = self.dataset.classes
        if class_str in classes:
            return classes.index(class_str) + 1

        raise ValueError('Invalid class string {}, not in {}'.format(
            class_str, classes))

    def create_slice_filter(self, point_cloud, area_extents,
                            ground_plane, ground_offset_dist, offset_dist):
        """Creates a slice filter to take a slice of the point cloud between
        ground_offset_dist and offset_dist above the ground plane.

        Args:
            point_cloud: Point cloud in the shape (3, N)
            area_extents: 3D area extents
            ground_plane: ground plane coefficients
            ground_offset_dist: min distance above the ground plane
            offset_dist: max distance above the ground

        Returns:
            A boolean mask of shape (N,) where
                True indicates the point should be kept
                False indicates the point should be removed
        """

        # Filter points within certain xyz range and offset from ground plane
        offset_filter = obj_utils.get_point_filter(point_cloud, area_extents,
                                                   ground_plane, offset_dist)

        # Filter points within ground_offset_dist of the road plane
        road_filter = obj_utils.get_point_filter(point_cloud, area_extents,
                                                 ground_plane,
                                                 ground_offset_dist)

        # XOR keeps points accepted by exactly one filter, i.e. (presumably,
        # when road_filter is a subset of offset_filter) the points between
        # the two heights.
        slice_filter = np.logical_xor(offset_filter, road_filter)
        return slice_filter

    def get_anchors_info(self, classes_name, anchor_strides, sample_name):
        """Returns the anchor info for a sample, as provided by
        MiniBatchUtils.get_anchors_info."""

        anchors_info = self.mini_batch_utils.get_anchors_info(classes_name,
                                                              anchor_strides,
                                                              sample_name)
        return anchors_info

    def get_ground_plane(self, sample_name):
        """Reads the ground plane for the sample.

        Args:
            sample_name: name of the sample, e.g. '000123'

        Returns:
            ground_plane: ground plane coefficients
        """
        ground_plane = obj_utils.get_road_plane(int(sample_name),
                                                self.dataset.planes_dir)
        return ground_plane

    def filter_labels(self, objects,
                      classes=None,
                      difficulty=None,
                      max_occlusion=None):
        """Filters ground truth labels based on class, difficulty, and
        maximum occlusion.

        Args:
            objects: A list of ground truth instances of Object Label
            classes: (optional) classes to filter by, if None
                all classes are used
            difficulty: (optional) KITTI difficulty rating as integer
                (0 - easy, 1 - medium, 2 - hard)
            max_occlusion: (optional) maximum occlusion; objects with a
                larger occlusion value are removed. 0 is a valid limit;
                None disables the check.

        Returns:
            filtered object labels, as the numpy array produced by
            np.asanyarray(objects) indexed with a boolean mask
        """
        if classes is None:
            classes = self.dataset.classes

        objects = np.asanyarray(objects)
        # Built-in bool: the np.bool alias was removed in NumPy 1.24
        keep = np.ones(len(objects), dtype=bool)

        for obj_idx, obj in enumerate(objects):
            if not self._check_class(obj, classes):
                keep[obj_idx] = False
            # Filter by difficulty (occlusion, truncation, and height)
            elif (difficulty is not None
                  and not self._check_difficulty(obj, difficulty)):
                keep[obj_idx] = False
            # `is not None` so that max_occlusion=0 still filters
            elif (max_occlusion is not None
                  and obj.occlusion > max_occlusion):
                keep[obj_idx] = False

        return objects[keep]

    def _check_difficulty(self, obj, difficulty):
        """Checks an object against a KITTI difficulty level.

        Args:
            obj: An instance of ground-truth Object Label
            difficulty: An int defining the KITTI difficulty rate

        Returns:
            True if the object's occlusion and truncation are at most, and
            its 2D box height (y2 - y1) at least, the limits for difficulty.
        """

        return ((obj.occlusion <= self.OCCLUSION[difficulty]) and
                (obj.truncation <= self.TRUNCATION[difficulty]) and
                (obj.y2 - obj.y1) >= self.HEIGHT[difficulty])

    def _check_class(self, obj, classes):
        """Checks whether an object's type is one of the given classes.

        Args:
            obj: An instance of ground-truth Object Label
            classes: collection of class name strings

        Returns:
            True if obj.type is in classes, False otherwise.
        """
        return obj.type in classes
class KittiUtils(object):
    """KITTI dataset helpers: cached config values, BEV map generation,
    point-cloud loading and filtering, voxel grids, anchor info and label
    filtering."""

    # Definition for difficulty levels
    # These values are from Kitti dataset
    # 0 - easy, 1 - medium, 2 - hard
    # Each tuple is indexed by difficulty (see _check_difficulty):
    # minimum 2D box height (y2 - y1), maximum occlusion, maximum truncation.
    HEIGHT = (40, 25, 25)
    OCCLUSION = (0, 1, 2)
    TRUNCATION = (0.15, 0.3, 0.5)

    def __init__(self, dataset):
        """Reads settings from the dataset config and prepares helpers.

        Args:
            dataset: dataset object exposing ``bev_source``, ``depth_dir``
                and ``config.kitti_utils_config``.

        Raises:
            FileNotFoundError: if the BEV source is 'depth' and the
                dataset's depth map directory does not exist.
        """

        self.dataset = dataset

        # Label Clusters
        self.label_cluster_utils = LabelClusterUtils(self.dataset)

        # Placeholders until get_clusters() runs at the end of __init__
        self.clusters, self.std_devs = [None, None]

        # BEV source from dataset config
        self.bev_source = self.dataset.bev_source

        # Parse config
        self.config = dataset.config.kitti_utils_config
        self.area_extents = np.reshape(self.config.area_extents, (3, 2))
        # Rows 0 and 2 of area_extents (presumably the x and z ranges)
        self.bev_extents = self.area_extents[[0, 2]]
        self.voxel_size = self.config.voxel_size
        self.anchor_strides = np.reshape(self.config.anchor_strides, (-1, 2))

        # The builder is handed self and may read the attributes set above;
        # keep this ordering.
        self.bev_generator = bev_generator_builder.build(
            self.config.bev_generator, self)

        self._density_threshold = self.config.density_threshold

        # Check that depth maps folder exists
        if self.bev_source == 'depth' and \
                not os.path.exists(self.dataset.depth_dir):
            raise FileNotFoundError(
                'Could not find depth maps, please run '
                'demos/save_lidar_depth_maps.py in wavedata first')

        # Mini Batch Utils
        self.mini_batch_utils = MiniBatchUtils(self.dataset)
        self._mini_batch_dir = self.mini_batch_utils.mini_batch_dir

        # Label Clusters
        self.clusters, self.std_devs = \
            self.label_cluster_utils.get_clusters()

    def class_str_to_index(self, class_str):
        """Converts an object class type string into an integer index.

        Args:
            class_str: the object type (e.g. 'Car', 'Pedestrian', or 'Cyclist')

        Returns:
            The corresponding integer index for a class type, starting at 1
            (0 is reserved for the background class).

        Raises:
            ValueError: if class_str is not in self.dataset.classes.
        """
        classes = self.dataset.classes
        if class_str in classes:
            return classes.index(class_str) + 1

        raise ValueError('Invalid class string {}, not in {}'.format(
            class_str, classes))

    def create_slice_filter(self, point_cloud, area_extents, ground_plane,
                            ground_offset_dist, offset_dist):
        """Creates a slice filter to take a slice of the point cloud between
        ground_offset_dist and offset_dist above the ground plane.

        Args:
            point_cloud: Point cloud in the shape (3, N)
            area_extents: 3D area extents
            ground_plane: ground plane coefficients
            ground_offset_dist: min distance above the ground plane
            offset_dist: max distance above the ground

        Returns:
            A boolean mask of shape (N,) where
                True indicates the point should be kept
                False indicates the point should be removed
        """

        # Filter points within certain xyz range and offset from ground plane
        offset_filter = obj_utils.get_point_filter(point_cloud, area_extents,
                                                   ground_plane, offset_dist)

        # Filter points within ground_offset_dist of the road plane
        road_filter = obj_utils.get_point_filter(point_cloud, area_extents,
                                                 ground_plane,
                                                 ground_offset_dist)

        # XOR keeps points accepted by exactly one filter, i.e. (presumably,
        # when road_filter is a subset of offset_filter) the points between
        # the two heights.
        slice_filter = np.logical_xor(offset_filter, road_filter)
        return slice_filter

    def create_bev_maps(self, point_cloud, ground_plane, output_indices=False):
        """Calculates bev maps.

        Args:
            point_cloud: point cloud
            ground_plane: ground_plane coefficients
            output_indices: (optional) forwarded unchanged to the BEV
                generator's generate_bev

        Returns:
            Dictionary with entries for each type of map (e.g. height, density)
        """

        bev_maps = self.bev_generator.generate_bev(
            self.bev_source,
            point_cloud,
            ground_plane,
            self.area_extents,
            self.voxel_size,
            output_indices=output_indices)

        return bev_maps

    def get_anchors_info(self, classes_name, anchor_strides, sample_name):
        """Returns the anchor info for a sample, as provided by
        MiniBatchUtils.get_anchors_info."""

        anchors_info = self.mini_batch_utils.get_anchors_info(
            classes_name, anchor_strides, sample_name)
        return anchors_info

    def get_point_cloud(self, source, img_idx, image_shape=None):
        """Loads the point cloud for a particular image.

        No area-extent or ground-plane filtering happens here; see
        _apply_offset_filter / _apply_slice_filter for that.

        Args:
            source: point cloud source; only 'lidar' is supported
            img_idx: An integer sample image index, e.g. 123 or 500
            image_shape: (optional) image dimensions (h, w). When given, it
                is passed to wavedata as im_size in (w, h) order; when None,
                im_size=None is passed instead.

        Returns:
            The point cloud from obj_utils.get_lidar_point_cloud; the
            filters in this class treat it as shape (3, N).

        Raises:
            ValueError: if source is not 'lidar'.
        """

        if source != 'lidar':
            raise ValueError("Invalid source {}".format(source))

        # wavedata wants im_size in (w, h) order. Callers such as
        # create_voxel_grid_3d pass no image_shape; previously that crashed
        # with a TypeError when subscripting None.
        # NOTE(review): assumes get_lidar_point_cloud accepts im_size=None
        # (skip image-bounds cropping) -- confirm against wavedata.
        im_size = None
        if image_shape is not None:
            im_size = [image_shape[1], image_shape[0]]

        point_cloud = obj_utils.get_lidar_point_cloud(
            img_idx,
            self.dataset.calib_dir,
            self.dataset.velo_dir,
            im_size=im_size)

        return point_cloud

    def get_ground_plane(self, sample_name):
        """Reads the ground plane for the sample.

        Args:
            sample_name: name of the sample, e.g. '000123'

        Returns:
            ground_plane: ground plane coefficients
        """
        ground_plane = obj_utils.get_road_plane(int(sample_name),
                                                self.dataset.planes_dir)
        return ground_plane

    def _apply_offset_filter(self, point_cloud, ground_plane, offset_dist=2.0):
        """Applies an offset filter to the point cloud.

        Args:
            point_cloud: A point cloud in the shape (3, N)
            ground_plane: ground plane coefficients,
                if None, will only filter to the area extents
            offset_dist: (optional) height above ground plane for filtering

        Returns:
            Points filtered with an offset filter in the shape (N, 3)
        """
        offset_filter = obj_utils.get_point_filter(point_cloud,
                                                   self.area_extents,
                                                   ground_plane, offset_dist)

        # Transpose point cloud into N x 3 points
        points = np.asarray(point_cloud).T

        filtered_points = points[offset_filter]

        return filtered_points

    def _apply_slice_filter(self,
                            point_cloud,
                            ground_plane,
                            height_lo=0.2,
                            height_hi=2.0):
        """Applies a slice filter to the point cloud.

        Args:
            point_cloud: A point cloud in the shape (3, N)
            ground_plane: ground plane coefficients
            height_lo: (optional) lower height for slicing
            height_hi: (optional) upper height for slicing

        Returns:
            Points filtered with a slice filter in the shape (N, 3)
        """

        slice_filter = self.create_slice_filter(point_cloud, self.area_extents,
                                                ground_plane, height_lo,
                                                height_hi)

        # Transpose point cloud into N x 3 points
        points = np.asarray(point_cloud).T

        filtered_points = points[slice_filter]

        return filtered_points

    def create_sliced_voxel_grid_2d(self,
                                    sample_name,
                                    source,
                                    image_shape=None):
        """Generates a slice-filtered 2D voxel grid from point cloud data.

        Args:
            sample_name: name of the sample to load, e.g. '000123'
            source: point cloud source; only 'lidar' is supported
            image_shape: (optional) image dimensions [h, w], forwarded to
                get_point_cloud

        Returns:
            voxel_grid_2d: 2D voxel grid (VoxelGrid2D) for the sample
        """
        img_idx = int(sample_name)
        ground_plane = self.get_ground_plane(sample_name)

        point_cloud = self.get_point_cloud(source,
                                           img_idx,
                                           image_shape=image_shape)
        filtered_points = self._apply_slice_filter(point_cloud, ground_plane)

        # Create Voxel Grid
        voxel_grid_2d = VoxelGrid2D()
        voxel_grid_2d.voxelize_2d(filtered_points,
                                  self.voxel_size,
                                  extents=self.area_extents,
                                  ground_plane=ground_plane,
                                  create_leaf_layout=True)

        return voxel_grid_2d

    def create_voxel_grid_3d(self,
                             sample_name,
                             ground_plane,
                             source='lidar',
                             filter_type='slice'):
        """Generates a filtered 3D voxel grid from point cloud data.

        Args:
            sample_name: name of the sample to load, e.g. '000123'
            ground_plane: ground plane coefficients
            source: point cloud source; only 'lidar' is supported by
                get_point_cloud
            filter_type: type of point filter to use
                'slice' for slice filtering (offset + ground)
                'offset' for offset filtering only
                'area' for area filtering only

        Returns:
            voxel_grid_3d: 3D voxel grid (VoxelGrid) for the sample

        Raises:
            ValueError: for an unknown filter_type or source.
        """
        img_idx = int(sample_name)

        points = self.get_point_cloud(source, img_idx)

        if filter_type == 'slice':
            filtered_points = self._apply_slice_filter(points, ground_plane)
        elif filter_type == 'offset':
            filtered_points = self._apply_offset_filter(points, ground_plane)
        elif filter_type == 'area':
            # A None ground plane will filter the points to the area extents
            filtered_points = self._apply_offset_filter(points, None)
        else:
            raise ValueError("Invalid filter_type {}, should be 'slice', "
                             "'offset', or 'area'".format(filter_type))

        # Create Voxel Grid
        voxel_grid_3d = VoxelGrid()
        voxel_grid_3d.voxelize(filtered_points,
                               self.voxel_size,
                               extents=self.area_extents)

        return voxel_grid_3d

    def filter_labels(self,
                      objects,
                      classes=None,
                      difficulty=None,
                      max_occlusion=None):
        """Filters ground truth labels based on class, difficulty, and
        maximum occlusion.

        Args:
            objects: A list of ground truth instances of Object Label
            classes: (optional) classes to filter by, if None
                all classes are used
            difficulty: (optional) KITTI difficulty rating as integer
                (0 - easy, 1 - medium, 2 - hard)
            max_occlusion: (optional) maximum occlusion; objects with a
                larger occlusion value are removed. 0 is a valid limit;
                None disables the check.

        Returns:
            filtered object labels, as the numpy array produced by
            np.asanyarray(objects) indexed with a boolean mask
        """
        if classes is None:
            classes = self.dataset.classes

        objects = np.asanyarray(objects)
        # Built-in bool: the np.bool alias was removed in NumPy 1.24
        keep = np.ones(len(objects), dtype=bool)

        for obj_idx, obj in enumerate(objects):
            if not self._check_class(obj, classes):
                keep[obj_idx] = False
            # Filter by difficulty (occlusion, truncation, and height)
            elif (difficulty is not None
                  and not self._check_difficulty(obj, difficulty)):
                keep[obj_idx] = False
            # `is not None` so that max_occlusion=0 still filters
            elif (max_occlusion is not None
                  and obj.occlusion > max_occlusion):
                keep[obj_idx] = False

        return objects[keep]

    def _check_difficulty(self, obj, difficulty):
        """Checks an object against a KITTI difficulty level.

        Args:
            obj: An instance of ground-truth Object Label
            difficulty: An int defining the KITTI difficulty rate

        Returns:
            True if the object's occlusion and truncation are at most, and
            its 2D box height (y2 - y1) at least, the limits for difficulty.
        """

        return ((obj.occlusion <= self.OCCLUSION[difficulty])
                and (obj.truncation <= self.TRUNCATION[difficulty])
                and (obj.y2 - obj.y1) >= self.HEIGHT[difficulty])

    def _check_class(self, obj, classes):
        """Checks whether an object's type is one of the given classes.

        Args:
            obj: An instance of ground-truth Object Label
            classes: collection of class name strings

        Returns:
            True if obj.type is in classes, False otherwise.
        """
        return obj.type in classes
Example #7
0
def _add_box_patch(axes, box):
    """Draws an [x1, y1, x2, y2] box on axes as an unfilled blue rectangle."""
    rect = patches.Rectangle((box[0], box[1]),
                             box[2] - box[0],
                             box[3] - box[1],
                             linewidth=2,
                             edgecolor='b',
                             facecolor='none')
    axes.add_patch(rect)


def main():
    """
    Visualization of 3D grid anchor generation, showing 2D projections
        in BEV and image space, and a 3D display of the anchors

    Anchors are generated from one hand-picked size (``fake_clusters``)
    over the dataset's area extents. They are drawn as BEV rectangles
    (metric and normalized) and as 2D and 3D boxes over the RGB image
    ``img_idx``. The camera calibration of ``img_idx`` is used for both
    the 2D projection and the 3D box drawing.
    """
    dataset_config = DatasetBuilder.copy_config(DatasetBuilder.KITTI_TRAIN)
    dataset_config.num_clusters[0] = 1
    dataset = DatasetBuilder.build_kitti_dataset(dataset_config)

    # Generate (or load) the label clusters for this config. The result is
    # not used below; the fixed fake_clusters size is used for the anchors.
    label_cluster_utils = LabelClusterUtils(dataset)
    label_cluster_utils.get_clusters()

    # Options
    img_idx = 1
    fake_clusters = np.array([[4, 2, 3]])
    fake_anchor_stride = [5.0, 5.0]
    ground_plane = [0, -1, 0, 1.72]

    anchor_3d_generator = grid_anchor_3d_generator.GridAnchor3dGenerator()

    # Extents used for the BEV projection. The BEV axis limits below
    # (0..80 by 0..70) correspond to these values.
    # NOTE(review): anchors are generated over
    # dataset.kitti_utils.area_extents; if those differ from these values
    # the BEV plot will be shifted -- confirm they match.
    area_extents = np.array([[-40, 40], [-5, 5], [0, 70]])

    # Generate anchors for cars only
    start_time = time.time()
    anchor_boxes_3d = anchor_3d_generator.generate(
        area_3d=dataset.kitti_utils.area_extents,
        anchor_3d_sizes=fake_clusters,
        anchor_stride=fake_anchor_stride,
        ground_plane=ground_plane)
    all_anchors = box_3d_encoder.box_3d_to_anchor(anchor_boxes_3d)
    end_time = time.time()
    print("Anchors generated in {} s".format(end_time - start_time))

    # Project into bev
    bev_boxes, bev_normalized_boxes = \
        anchor_projector.project_to_bev(all_anchors, area_extents[[0, 2]])

    _, (bev_axes, bev_normalized_axes) = \
        plt.subplots(1, 2, figsize=(16, 7))
    bev_axes.set_xlim(0, 80)
    bev_axes.set_ylim(70, 0)
    bev_normalized_axes.set_xlim(0, 1.0)
    bev_normalized_axes.set_ylim(1, 0.0)

    plt.show(block=False)

    for box in bev_boxes:
        _add_box_patch(bev_axes, box)

    for normalized_box in bev_normalized_boxes:
        _add_box_patch(bev_normalized_axes, normalized_box)

    rgb_fig, rgb_2d_axes, rgb_3d_axes = \
        vis_utils.visualization(dataset.rgb_image_dir, img_idx)
    plt.show(block=False)

    # NOTE(review): the image shape comes from sample_names[img_idx], while
    # the displayed image and calibration are looked up by img_idx itself;
    # these presumably only agree when sample_names[img_idx] names the same
    # frame -- confirm for the split in use.
    image_path = dataset.get_rgb_image_path(dataset.sample_names[img_idx])
    image_shape = np.array(Image.open(image_path)).shape

    # Projection matrix of the displayed sample. It is shared by the 2D
    # projection and the 3D drawing so both line up with the same image.
    # (Previously the 3D boxes used the calibration of sample 0.)
    stereo_calib_p2 = calib_utils.read_calibration(dataset.calib_dir,
                                                   img_idx).p2

    start_time = time.time()
    rgb_boxes, _ = anchor_projector.project_to_image_space(all_anchors,
                                                           stereo_calib_p2,
                                                           image_shape)
    end_time = time.time()
    print("Anchors projected in {} s".format(end_time - start_time))

    # Overlay boxes on images
    for anchor_idx, anchor_box_3d in enumerate(anchor_boxes_3d):
        obj_label = box_3d_encoder.box_3d_to_object_label(anchor_box_3d)

        # Draw 3D boxes
        vis_utils.draw_box_3d(rgb_3d_axes, obj_label, stereo_calib_p2)

        # Draw 2D boxes
        _add_box_patch(rgb_2d_axes, rgb_boxes[anchor_idx])

        # Redraw periodically rather than after every anchor
        if anchor_idx % 32 == 0:
            rgb_fig.canvas.draw()

    plt.show(block=True)
Example #8
0
def main():
    """
    Analyses the spread of label sizes used for clustering.

    Reads the labels of every sample in ``dataset.cluster_split`` and keeps
    the first entry returned by ``LabelClusterUtils._filter_labels_by_class``
    (presumably the dimensions of the first class in ``dataset.classes`` --
    verify against LabelClusterUtils). It then prints single-cluster KMeans
    centres and standard deviations for the labels whose length (column 0)
    lies more than 2 or 3 standard deviations below / above the mean.

    NOTE(review): the original docstring listed all_clusters / all_std_devs
    as return values, but no return statement is visible in this code;
    results are printed -- confirm.
    """

    dataset = DatasetBuilder.build_kitti_dataset(DatasetBuilder.KITTI_TRAIN)

    # Calculate the remaining clusters
    # Load labels corresponding to the sample list for clustering
    sample_list = dataset.load_sample_names(dataset.cluster_split)
    all_dims = []

    num_samples = len(sample_list)
    for sample_idx in range(num_samples):

        # In-place progress counter (carriage return, no newline)
        sys.stdout.write("\rClustering labels {} / {}".format(
            sample_idx + 1, num_samples))
        sys.stdout.flush()

        sample_name = sample_list[sample_idx]
        img_idx = int(sample_name)

        obj_labels = obj_utils.read_labels(dataset.label_dir, img_idx)
        filtered_lwh = LabelClusterUtils._filter_labels_by_class(
                obj_labels, dataset.classes)

        # Only entry 0 is collected (see NOTE in the docstring); empty
        # entries are skipped.
        if filtered_lwh[0]:
            all_dims.extend(filtered_lwh[0])

    all_dims = np.array(all_dims)
    print("\nFinished reading labels, clustering data...\n")

    # Print 3 decimal places
    np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})

    # Calculate average cluster (KMeans with a single cluster; its centre is
    # the per-column mean of all_dims)
    k_means = KMeans(n_clusters=1,
                     random_state=0).fit(all_dims)

    cluster_centre = k_means.cluster_centers_[0]

    # Calculate std. dev
    std_dev = np.std(all_dims, axis=0)

    # Calculate 2 and 3 standard deviations below the mean
    two_sigma_length_lo = cluster_centre[0] - 2 * std_dev[0]
    three_sigma_length_lo = cluster_centre[0] - 3 * std_dev[0]

    # Keep only the labels whose length is more than two (resp. three)
    # std devs below the mean, and re-cluster them.
    # NOTE(review): if no label falls below a bound, the subset is empty and
    # KMeans(...).fit is expected to raise -- consider guarding.
    small_mask_2 = all_dims[:, 0] < two_sigma_length_lo
    small_dims_2 = all_dims[small_mask_2]

    small_mask_3 = all_dims[:, 0] < three_sigma_length_lo
    small_dims_3 = all_dims[small_mask_3]

    small_k_means_2 = KMeans(n_clusters=1, random_state=0).fit(small_dims_2)
    small_k_means_3 = KMeans(n_clusters=1, random_state=0).fit(small_dims_3)
    small_std_dev_2 = np.std(small_dims_2, axis=0)
    small_std_dev_3 = np.std(small_dims_3, axis=0)

    print('small_k_means_2:', small_k_means_2.cluster_centers_)
    print('small_k_means_3:', small_k_means_3.cluster_centers_)
    print('small_std_dev_2:', small_std_dev_2)
    print('small_std_dev_3:', small_std_dev_3)

    # Calculate 2 and 3 standard deviations above the mean
    two_sigma_length_hi = cluster_centre[0] + 2 * std_dev[0]
    three_sigma_length_hi = cluster_centre[0] + 3 * std_dev[0]

    # Keep only the labels whose length is more than two (resp. three)
    # std devs above the mean, and re-cluster them (same empty-subset
    # caveat as above).
    large_mask_2 = all_dims[:, 0] > two_sigma_length_hi
    large_dims_2 = all_dims[large_mask_2]

    large_mask_3 = all_dims[:, 0] > three_sigma_length_hi
    large_dims_3 = all_dims[large_mask_3]

    large_k_means_2 = KMeans(n_clusters=1, random_state=0).fit(large_dims_2)
    large_k_means_3 = KMeans(n_clusters=1, random_state=0).fit(large_dims_3)

    large_std_dev_2 = np.std(large_dims_2, axis=0)
    large_std_dev_3 = np.std(large_dims_3, axis=0)

    print('large_k_means_2:', large_k_means_2.cluster_centers_)
    print('large_k_means_3:', large_k_means_3.cluster_centers_)
    print('large_std_dev_2:', large_std_dev_2)
    print('large_std_dev_3:', large_std_dev_3)