def connected_components(images):
  """Labels the connected components in a batch of images.

  A component is a set of pixels in a single input image, which are all
  adjacent and all have the same non-zero value. The components are computed
  using a squared connectivity of one (all True entries are joined with their
  neighbors above, below, left, and right). Components across all images have
  consecutive ids 1 through n. Components are labeled according to the first
  pixel of the component appearing in row-major order (lexicographic order by
  image_index_in_batch, row, col). Zero entries all have an output id of 0.

  This op is equivalent to `scipy.ndimage.measurements.label` on a 2D array
  with the default structuring element (which is the connectivity used here).

  Args:
    images: A 2D (H, W) or 3D (N, H, W) Tensor of boolean image(s).

  Returns:
    Components with the same shape as `images`. False entries in `images` have
    value 0, and all True entries map to a component id > 0.

  Raises:
    TypeError: if `images` is not 2D or 3D, or if its static rank is unknown.
  """
  with ops.name_scope("connected_components"):
    image_or_images = ops.convert_to_tensor(images, name="images")
    # Read the static rank once. `ndims` is None when the rank is unknown, so
    # that case falls through to the TypeError below instead of failing inside
    # `len()` on the shape object.
    rank = image_or_images.get_shape().ndims
    if rank == 2:
      # Add a leading batch dimension; it is stripped again before returning.
      images = image_or_images[None, :, :]
    elif rank == 3:
      images = image_or_images
    else:
      raise TypeError(
          "images should have rank 2 (HW) or 3 (NHW). Static shape is %s" %
          image_or_images.get_shape())
    components = gen_image_ops.image_connected_components(images)

    # TODO(ringwalt): Component id renaming should be done in the op, to avoid
    # constructing multiple additional large tensors.
    components_flat = array_ops.reshape(components, [-1])
    # `id_index` maps every pixel to the position of its raw id in
    # `unique_ids`; the renumbering below is expressed over those positions.
    unique_ids, id_index = array_ops.unique(components_flat)
    # Positions in `unique_ids` holding the background id 0 (zero or one entry,
    # since the ids are unique).
    id_is_zero = array_ops.where(math_ops.equal(unique_ids, 0))[:, 0]
    # Map each nonzero id to consecutive values 1..k, where k is the number of
    # nonzero unique ids.
    nonzero_consecutive_ids = math_ops.range(
        array_ops.shape(unique_ids)[0] - array_ops.shape(id_is_zero)[0]) + 1

    def no_zero():
      # No need to insert a zero into the ids.
      return nonzero_consecutive_ids

    def has_zero():
      # Insert a zero in the consecutive ids where zero appears in unique_ids.
      # id_is_zero has length 1.
      zero_id_ind = math_ops.to_int32(id_is_zero[0])
      ids_before = nonzero_consecutive_ids[:zero_id_ind]
      ids_after = nonzero_consecutive_ids[zero_id_ind:]
      return array_ops.concat([ids_before, [0], ids_after], axis=0)

    new_ids = control_flow_ops.cond(
        math_ops.equal(array_ops.shape(id_is_zero)[0], 0), no_zero, has_zero)
    # Look up the new id for every pixel and restore the batched image shape.
    components = array_ops.reshape(
        array_ops.gather(new_ids, id_index), array_ops.shape(components))
    if rank == 2:
      # Drop the batch dimension that was added for a single 2D input.
      return components[0, :, :]
    else:
      return components