class TestSegmentationDataset:
    """Tests for the attributes and frame handling of `SegmentationDataset`.

    Every test starts from a fresh 1x1 binary segmentation dataset with a
    single declared segment (segment number 1), created in `setup_method`.
    """

    def setup_method(self):
        """Create the default dataset used by most tests.

        Named `setup_method` (pytest's xunit-style hook) instead of the
        nose-style `setup`, which pytest deprecated in 7.2 and no longer
        calls since 8.0.
        """
        self.dataset = SegmentationDataset(
            rows=1, columns=1, segmentation_type=SegmentationType.BINARY)
        self.setup_dummy_segment(self.dataset)

    def setup_dummy_segment(self, dataset: pydicom.Dataset):
        """Append a minimal segment item with SegmentNumber 1 to `dataset`."""
        ds = pydicom.Dataset()
        ds.SegmentNumber = 1
        dataset.SegmentSequence.append(ds)

    def generate_dummy_source_image(self):
        """Return a minimal source image dataset with freshly generated UIDs."""
        ds = pydicom.Dataset()
        ds.SOPClassUID = '1.2.840.10008.5.1.4.1.1.2'  # CT Image Storage
        ds.SOPInstanceUID = pydicom.uid.generate_uid()
        ds.SeriesInstanceUID = pydicom.uid.generate_uid()
        return ds

    def test_dataset_is_writable(self):
        # NOTE(review): the file is written by name while the temporary file
        # is still open; this is expected to work on POSIX but may fail on
        # Windows - confirm if Windows CI matters.
        with tempfile.NamedTemporaryFile() as ofile:
            self.dataset.save_as(ofile.name)

    def test_dataset_has_valid_file_meta(self):
        pydicom.dataset.validate_file_meta(self.dataset.file_meta)

    def test_file_meta_has_information_group_length_computed(self):
        assert 'FileMetaInformationGroupLength' in self.dataset.file_meta
        assert self.dataset.file_meta.FileMetaInformationGroupLength > 0

    def test_mandatory_sop_common(self):
        assert self.dataset.SOPClassUID == '1.2.840.10008.5.1.4.1.1.66.4'
        assert 'SOPInstanceUID' in self.dataset

    def test_mandatory_enhanced_equipment_elements(self):
        """http://dicom.nema.org/medical/dicom/current/output/chtml/part03/sect_C.7.5.2.html#table_C.7-8b"""
        assert self.dataset.Manufacturer == 'pydicom-seg'
        assert self.dataset.ManufacturerModelName == 'https://github.com/razorx89/pydicom-seg'
        assert self.dataset.DeviceSerialNumber == '0'
        assert self.dataset.SoftwareVersions == __version__

    def test_mandatory_frame_of_reference_elements(self):
        """http://dicom.nema.org/medical/dicom/current/output/chtml/part03/sect_C.7.4.html#table_C.7-6"""
        assert 'FrameOfReferenceUID' in self.dataset

    def test_mandatory_gernal_series_elements(self):
        """http://dicom.nema.org/medical/dicom/current/output/chtml/part03/sect_C.7.3.html#table_C.7-5a"""
        assert self.dataset.Modality == 'SEG'
        assert 'SeriesInstanceUID' in self.dataset

    def test_mandatory_segmentation_series_elements(self):
        """http://dicom.nema.org/medical/dicom/current/output/chtml/part03/sect_C.8.20.html#table_C.8.20-1"""
        assert self.dataset.Modality == 'SEG'
        assert self.dataset.SeriesNumber

    def test_mandatory_image_pixel_elements(self):
        """http://dicom.nema.org/medical/dicom/current/output/chtml/part03/sect_C.7.6.3.html#table_C.7-11a"""
        assert self.dataset.SamplesPerPixel >= 1
        assert self.dataset.PhotometricInterpretation in [
            'MONOCHROME1', 'MONOCHROME2'
        ]
        assert 'Rows' in self.dataset
        assert 'Columns' in self.dataset
        assert self.dataset.BitsAllocated in [1, 8, 16]
        assert 0 < self.dataset.BitsStored <= self.dataset.BitsAllocated
        assert self.dataset.HighBit == self.dataset.BitsStored - 1
        assert self.dataset.PixelRepresentation in [0, 1]

    def test_mandatory_and_common_segmentation_image_elements(self):
        """http://dicom.nema.org/medical/dicom/current/output/chtml/part03/sect_C.8.20.2.html#table_C.8.20-2"""
        assert 'ImageType' in self.dataset
        # Only the leading values are compared; zip stops at the shorter
        # of the two sequences.
        assert all(
            a == b
            for a, b in zip(self.dataset.ImageType, ['DERIVED', 'PRIMARY']))
        assert self.dataset.InstanceNumber
        assert self.dataset.ContentLabel == 'SEGMENTATION'
        assert 'ContentCreatorName' in self.dataset
        assert 'ContentDescription' in self.dataset
        assert self.dataset.SamplesPerPixel == 1
        assert self.dataset.PhotometricInterpretation == 'MONOCHROME2'
        assert self.dataset.PixelRepresentation == 0
        assert self.dataset.LossyImageCompression == '00'
        assert 'SegmentSequence' in self.dataset

    def test_mandatory_binary_segmentation_image_elements(self):
        """http://dicom.nema.org/medical/dicom/current/output/chtml/part03/sect_C.8.20.2.html#table_C.8.20-2"""
        assert self.dataset.BitsAllocated == 1
        assert self.dataset.BitsStored == 1
        assert self.dataset.HighBit == 0
        assert self.dataset.SegmentationType == 'BINARY'

    @pytest.mark.parametrize('fractional_type', ['PROBABILITY', 'OCCUPANCY'])
    def test_mandatory_fractional_segmentation_image_elements(
            self, fractional_type):
        """http://dicom.nema.org/medical/dicom/current/output/chtml/part03/sect_C.8.20.2.html#table_C.8.20-2"""
        dataset = SegmentationDataset(
            rows=1,
            columns=1,
            segmentation_type=SegmentationType.FRACTIONAL,
            segmentation_fractional_type=SegmentationFractionalType(
                fractional_type))
        assert dataset.BitsAllocated == 8
        assert dataset.BitsStored == 8
        assert dataset.HighBit == 7  # Little Endian
        assert dataset.SegmentationType == 'FRACTIONAL'
        assert dataset.SegmentationFractionalType == fractional_type
        assert dataset.MaximumFractionalValue == 255

    def test_mandatory_multi_frame_functional_groups_elements(self):
        """http://dicom.nema.org/medical/dicom/current/output/chtml/part03/sect_C.7.6.16.html#table_C.7.6.16-1"""
        assert 'SharedFunctionalGroupsSequence' in self.dataset
        assert len(self.dataset.SharedFunctionalGroupsSequence) == 1
        assert 'PerFrameFunctionalGroupsSequence' in self.dataset
        assert self.dataset.NumberOfFrames == 0
        assert self.dataset.InstanceNumber
        assert 'ContentDate' in self.dataset
        assert 'ContentTime' in self.dataset

    def test_timestamps_exist(self):
        assert 'InstanceCreationDate' in self.dataset
        assert 'InstanceCreationTime' in self.dataset
        # Series and content timestamps are expected to equal the instance
        # creation timestamp.
        assert self.dataset.InstanceCreationDate == self.dataset.SeriesDate
        assert self.dataset.InstanceCreationTime == self.dataset.SeriesTime
        assert self.dataset.InstanceCreationDate == self.dataset.ContentDate
        assert self.dataset.InstanceCreationTime == self.dataset.ContentTime

    def test_exception_on_invalid_image_dimensions(self):
        with pytest.raises(ValueError, match='.*must be larger than zero'):
            SegmentationDataset(rows=0,
                                columns=0,
                                segmentation_type=SegmentationType.BINARY)

    @pytest.mark.parametrize('max_fractional_value', [-1, 0, 256])
    def test_exception_on_invalid_max_fractional_value(self,
                                                       max_fractional_value):
        with pytest.raises(ValueError,
                           match='Invalid maximum fractional value.*'):
            SegmentationDataset(
                rows=1,
                columns=1,
                segmentation_type=SegmentationType.FRACTIONAL,
                max_fractional_value=max_fractional_value,
            )

    def test_exception_when_adding_frame_with_wrong_rank(self):
        with pytest.raises(ValueError, match='.*expecting 2D image'):
            self.dataset.add_frame(np.zeros((1, 1, 1), dtype=np.uint8), 1)

    def test_exception_when_adding_frame_with_wrong_shape(self):
        with pytest.raises(ValueError, match='.*expecting \\d+x\\d+ images'):
            self.dataset.add_frame(np.zeros((2, 1), dtype=np.uint8), 1)

    @pytest.mark.parametrize('segmentation_type,dtype',
                             [(SegmentationType.BINARY, np.float32),
                              (SegmentationType.FRACTIONAL, np.uint8)])
    def test_exception_when_adding_frame_with_wrong_data_type(
            self, segmentation_type, dtype):
        dataset = SegmentationDataset(rows=1,
                                      columns=1,
                                      segmentation_type=segmentation_type)
        with pytest.raises(ValueError, match='.*requires.*?data type'):
            dataset.add_frame(np.zeros((1, 1), dtype=dtype), 1)

    def test_adding_frame_increases_number_of_frames(self):
        old_count = self.dataset.NumberOfFrames
        self.dataset.add_frame(np.zeros((1, 1), dtype=np.uint8), 1)
        assert self.dataset.NumberOfFrames == old_count + 1

    def test_adding_binary_frame_modifies_pixel_data(self):
        dataset = SegmentationDataset(
            rows=2, columns=2, segmentation_type=SegmentationType.BINARY)
        self.setup_dummy_segment(dataset)
        assert len(dataset.PixelData) == 0

        # One 2x2 frame at one bit per pixel is expected to fit in one byte.
        dataset.add_frame(np.zeros((2, 2), dtype=np.uint8), 1)
        assert len(dataset.PixelData) == 1

        # Three frames (12 bits in total) are expected to need two bytes.
        for _ in range(2):
            dataset.add_frame(np.ones((2, 2), dtype=np.uint8), 1)
        assert len(dataset.PixelData) == 2

    def test_adding_fractional_frame_modifies_pixel_data(self):
        dataset = SegmentationDataset(
            rows=2, columns=2, segmentation_type=SegmentationType.FRACTIONAL)
        self.setup_dummy_segment(dataset)
        assert len(dataset.PixelData) == 0

        # One 2x2 frame at eight bits per pixel is expected to add 4 bytes.
        dataset.add_frame(np.zeros((2, 2), dtype=np.float32), 1)
        assert len(dataset.PixelData) == 4

        for _ in range(2):
            dataset.add_frame(np.ones((2, 2), dtype=np.float32), 1)
        assert len(dataset.PixelData) == 12

    def test_adding_frame_with_reference_creates_referenced_series_sequence(
            self):
        assert 'ReferencedSeriesSequence' not in self.dataset

        dummy = self.generate_dummy_source_image()
        self.dataset.add_frame(np.zeros((1, 1), np.uint8), 1, [dummy])
        assert 'ReferencedSeriesSequence' in self.dataset

        series_sequence = self.dataset.ReferencedSeriesSequence
        assert len(series_sequence) == 1
        assert series_sequence[0].SeriesInstanceUID == dummy.SeriesInstanceUID
        assert 'ReferencedInstanceSequence' in series_sequence[0]

        instance_sequence = series_sequence[0].ReferencedInstanceSequence
        assert len(instance_sequence) == 1
        assert instance_sequence[0].ReferencedSOPClassUID == dummy.SOPClassUID
        assert instance_sequence[
            0].ReferencedSOPInstanceUID == dummy.SOPInstanceUID

    def test_adding_frames_with_different_references_from_same_series(self):
        dummy1 = self.generate_dummy_source_image()
        dummy2 = self.generate_dummy_source_image()
        # Force both source images into the same series.
        dummy2.SeriesInstanceUID = dummy1.SeriesInstanceUID
        self.dataset.add_frame(np.zeros((1, 1), np.uint8), 1, [dummy1])
        self.dataset.add_frame(np.zeros((1, 1), np.uint8), 1, [dummy2])

        series_sequence = self.dataset.ReferencedSeriesSequence
        assert len(series_sequence) == 1
        assert series_sequence[0].SeriesInstanceUID == dummy1.SeriesInstanceUID

        instance_sequence = series_sequence[0].ReferencedInstanceSequence
        assert len(instance_sequence) == 2
        assert instance_sequence[
            0].ReferencedSOPInstanceUID == dummy1.SOPInstanceUID
        assert instance_sequence[
            1].ReferencedSOPInstanceUID == dummy2.SOPInstanceUID

    def test_adding_frames_with_different_references_from_different_series(
            self):
        dummies = [self.generate_dummy_source_image() for _ in range(2)]
        self.dataset.add_frame(np.zeros((1, 1), np.uint8), 1, [dummies[0]])
        self.dataset.add_frame(np.zeros((1, 1), np.uint8), 1, [dummies[1]])

        series_sequence = self.dataset.ReferencedSeriesSequence
        assert len(series_sequence) == 2
        assert series_sequence[0].SeriesInstanceUID == dummies[
            0].SeriesInstanceUID
        assert series_sequence[1].SeriesInstanceUID == dummies[
            1].SeriesInstanceUID

        instance_sequence = series_sequence[0].ReferencedInstanceSequence
        assert len(instance_sequence) == 1
        assert instance_sequence[0].ReferencedSOPInstanceUID == dummies[
            0].SOPInstanceUID

        instance_sequence = series_sequence[1].ReferencedInstanceSequence
        assert len(instance_sequence) == 1
        assert instance_sequence[0].ReferencedSOPInstanceUID == dummies[
            1].SOPInstanceUID

    def test_adding_instance_reference_multiple_times(self):
        dummy = self.generate_dummy_source_image()

        # The first insertion reports a new item, the repeated one does not.
        item_added = self.dataset.add_instance_reference(dummy)
        assert item_added
        item_added = self.dataset.add_instance_reference(dummy)
        assert not item_added

        series_sequence = self.dataset.ReferencedSeriesSequence
        assert len(series_sequence) == 1
        assert series_sequence[0].SeriesInstanceUID == dummy.SeriesInstanceUID
        assert len(series_sequence[0].ReferencedInstanceSequence) == 1

    def test_adding_frame_increases_count_of_per_functional_groups_sequence(
            self):
        assert len(self.dataset.PerFrameFunctionalGroupsSequence) == 0
        self.dataset.add_frame(np.zeros((1, 1), np.uint8), 1)
        assert len(self.dataset.PerFrameFunctionalGroupsSequence) == 1

    def test_adding_frame_adds_derivation_image_sequence_to_per_frame_functional_group_item(
            self):
        frame_item = self.dataset.add_frame(np.zeros((1, 1), np.uint8), 1)
        assert 'DerivationImageSequence' in frame_item

        dummy = self.generate_dummy_source_image()
        frame_item = self.dataset.add_frame(np.zeros((1, 1), np.uint8), 1,
                                            [dummy])
        assert 'SourceImageSequence' in frame_item.DerivationImageSequence[0]
        assert len(
            frame_item.DerivationImageSequence[0].SourceImageSequence) == 1

    def test_adding_frame_adds_referenced_segment_to_per_frame_functional_group_item(
            self):
        frame_item = self.dataset.add_frame(np.zeros((1, 1), np.uint8), 1)
        assert 'SegmentIdentificationSequence' in frame_item
        assert len(frame_item.SegmentIdentificationSequence) == 1

        segment_id_item = frame_item.SegmentIdentificationSequence[0]
        assert 'ReferencedSegmentNumber' in segment_id_item
        assert segment_id_item.ReferencedSegmentNumber == 1

    def test_exception_on_adding_frame_with_non_existing_segment(self):
        # Only segment number 1 is declared by setup_dummy_segment.
        with pytest.raises(IndexError, match='Segment not found.*'):
            self.dataset.add_frame(np.zeros((1, 1), np.uint8), 2)

    def test_add_dimension_organization(self):
        assert 'DimensionOrganizationSequence' not in self.dataset
        assert 'DimensionIndexSequence' not in self.dataset

        seq = DimensionOrganizationSequence()
        seq.add_dimension('ReferencedSegmentNumber',
                          'SegmentIdentificationSequence')
        seq.add_dimension('ImagePositionPatient', 'PlanePositionSequence')
        self.dataset.add_dimension_organization(seq)

        assert len(self.dataset.DimensionOrganizationSequence) == 1
        assert len(self.dataset.DimensionIndexSequence) == 2
        assert self.dataset.DimensionIndexSequence[
            0].DimensionDescriptionLabel == 'ReferencedSegmentNumber'
        assert self.dataset.DimensionIndexSequence[
            1].DimensionDescriptionLabel == 'ImagePositionPatient'

    def test_add_dimension_organization_duplicate(self):
        seq = DimensionOrganizationSequence()
        seq.add_dimension('ReferencedSegmentNumber',
                          'SegmentIdentificationSequence')
        seq.add_dimension('ImagePositionPatient', 'PlanePositionSequence')
        self.dataset.add_dimension_organization(seq)

        # Adding the very same organization a second time must be rejected.
        with pytest.raises(ValueError,
                           match='Dimension organization with UID.*'):
            self.dataset.add_dimension_organization(seq)

    def test_add_multiple_dimension_organizations(self):
        for _ in range(2):
            seq = DimensionOrganizationSequence()
            seq.add_dimension('ReferencedSegmentNumber',
                              'SegmentIdentificationSequence')
            seq.add_dimension('ImagePositionPatient', 'PlanePositionSequence')
            self.dataset.add_dimension_organization(seq)

        assert len(self.dataset.DimensionOrganizationSequence) == 2
        assert len(self.dataset.DimensionIndexSequence) == 4
def write(self, segmentation: sitk.Image,
          source_images: List[pydicom.Dataset]) -> pydicom.Dataset:
    """Writes a DICOM-SEG dataset from a segmentation image and the
    corresponding DICOM source images.

    Each label that is present in the segmentation and declared in the
    DICOM template is encoded as a series of binary frames, one frame per
    slice.

    Args:
        segmentation: A `SimpleITK.Image` with integer labels and a single
            component per spatial location.
        source_images: A list of `pydicom.Dataset` which are the source
            images for the segmentation image.

    Returns:
        A `pydicom.Dataset` instance with all necessary information and
        meta information for writing the dataset to disk.

    Raises:
        ValueError: If the segmentation is not 3D, has more than one
            component per voxel, is not of an unsigned integer type,
            contains no labels, contains labels that are not declared in
            the template (unless skipping missing segments is enabled), or
            if no label remains for encoding.
    """
    if segmentation.GetDimension() != 3:
        raise ValueError("Only 3D segmentation data is supported")

    if segmentation.GetNumberOfComponentsPerPixel() > 1:
        raise ValueError("Multi-class segmentations can only be "
                         "represented with a single component per voxel")

    if segmentation.GetPixelID() not in [
            sitk.sitkUInt8,
            sitk.sitkUInt16,
            sitk.sitkUInt32,
            sitk.sitkUInt64,
    ]:
        raise ValueError("Unsigned integer data type required")

    # TODO Add further checks if source images are from the same series
    slice_to_source_images = self._map_source_images_to_segmentation(
        segmentation, source_images)

    # Compute unique labels and their respective bounding boxes
    label_statistics_filter = sitk.LabelStatisticsImageFilter()
    label_statistics_filter.Execute(segmentation, segmentation)
    unique_labels = {
        x for x in label_statistics_filter.GetLabels() if x != 0
    }
    if not unique_labels:
        raise ValueError("Segmentation does not contain any labels")

    # Check if all present labels were declared in the DICOM template
    declared_segments = {
        x.SegmentNumber for x in self._template.SegmentSequence
    }
    missing_declarations = unique_labels.difference(declared_segments)
    if missing_declarations:
        # Sorted so that the message is stable and easy to read.
        missing_segment_numbers = ", ".join(
            str(x) for x in sorted(missing_declarations))
        message = (
            f"Skipping segment(s) {missing_segment_numbers}, since their "
            "declaration is missing in the DICOM template")
        if not self._skip_missing_segment:
            raise ValueError(message)
        logger.warning(message)
    labels_to_process = unique_labels.intersection(declared_segments)
    if not labels_to_process:
        raise ValueError("No segments found for encoding as DICOM-SEG")

    # Compute bounding boxes for each present label and optionally restrict
    # the volume to serialize to the joined maximum extent
    bboxs = {
        x: label_statistics_filter.GetBoundingBox(x)
        for x in labels_to_process
    }
    if self._inplane_cropping:
        # Even bounding box entries are the lower bounds and odd entries the
        # (inclusive) upper bounds, hence the +1 for an exclusive end.
        min_x, min_y, _ = np.min([x[::2] for x in bboxs.values()],
                                 axis=0).tolist()
        max_x, max_y, _ = (
            np.max([x[1::2] for x in bboxs.values()], axis=0) + 1).tolist()
        logger.info(
            "Serializing cropped image planes starting at coordinates "
            f"({min_x}, {min_y}) with size ({max_x - min_x}, {max_y - min_y})"
        )
    else:
        min_x, min_y = 0, 0
        max_x, max_y = segmentation.GetWidth(), segmentation.GetHeight()
        logger.info(
            f"Serializing image planes at full size ({max_x}, {max_y})")

    # Create target dataset for storing serialized data
    result = SegmentationDataset(
        reference_dicom=source_images[0] if source_images else None,
        rows=max_y - min_y,
        columns=max_x - min_x,
        segmentation_type=SegmentationType.BINARY,
    )
    dimension_organization = DimensionOrganizationSequence()
    dimension_organization.add_dimension("ReferencedSegmentNumber",
                                         "SegmentIdentificationSequence")
    dimension_organization.add_dimension("ImagePositionPatient",
                                         "PlanePositionSequence")
    result.add_dimension_organization(dimension_organization)
    writer_utils.copy_segmentation_template(
        target=result,
        template=self._template,
        segments=labels_to_process,
        skip_missing_segment=self._skip_missing_segment,
    )
    writer_utils.set_shared_functional_groups_sequence(
        target=result, segmentation=segmentation)

    # Use the ImageOrientationPatient value from the DICOM source rather
    # than the one derived from the segmentation. Without source images (or
    # without that attribute) the segmentation-derived value is kept, in
    # line with the optional `reference_dicom` above.
    if source_images and "ImageOrientationPatient" in source_images[0]:
        shared_fg_item = result.SharedFunctionalGroupsSequence[0]
        shared_fg_item.PlaneOrientationSequence[
            0].ImageOrientationPatient = source_images[
                0].ImageOrientationPatient

    buffer = sitk.GetArrayFromImage(segmentation)
    depth = segmentation.GetDepth()
    for segment in labels_to_process:
        logger.info(f"Processing segment {segment}")

        if self._skip_empty_slices:
            # Restrict the slice range to the bounding box of this segment.
            bbox = bboxs[segment]
            min_z, max_z = bbox[4], bbox[5] + 1
        else:
            min_z, max_z = 0, depth
        # max_z is an exclusive bound, so the last processed slice is
        # max_z - 1.
        logger.info(
            "Total number of slices that will be processed for segment "
            f"{segment} is {max_z - min_z} "
            f"(inclusive from {min_z} to {max_z - 1})")

        skipped_slices = []
        for slice_idx in range(min_z, max_z):
            frame_index = (min_x, min_y, slice_idx)
            frame_position = segmentation.TransformIndexToPhysicalPoint(
                frame_index)
            # Binary mask of the current segment within the (cropped) plane.
            frame_data = np.equal(
                buffer[slice_idx, min_y:max_y, min_x:max_x], segment)
            if self._skip_empty_slices and not frame_data.any():
                skipped_slices.append(slice_idx)
                continue

            frame_fg_item = result.add_frame(
                data=frame_data.astype(np.uint8),
                referenced_segment=segment,
                referenced_images=slice_to_source_images[slice_idx],
            )

            frame_fg_item.FrameContentSequence = [pydicom.Dataset()]
            frame_fg_item.FrameContentSequence[0].DimensionIndexValues = [
                segment,  # Segment number
                slice_idx - min_z + 1,  # Slice index within cropped volume
            ]
            frame_fg_item.PlanePositionSequence = [pydicom.Dataset()]
            frame_fg_item.PlanePositionSequence[0].ImagePositionPatient = [
                f"{x:e}" for x in frame_position
            ]

        if skipped_slices:
            logger.info(f"Skipped empty slices for segment {segment}: "
                        f'{", ".join([str(x) for x in skipped_slices])}')

    # Report how much smaller the encoded pixel data is compared to storing
    # every slice of every segment at full size (one bit per voxel).
    if self._inplane_cropping or self._skip_empty_slices:
        num_encoded_bytes = len(result.PixelData)
        total_bits = (segmentation.GetWidth() * segmentation.GetHeight() *
                      depth * len(result.SegmentSequence))
        # Round up to whole bytes; plain floor division yields zero for
        # volumes smaller than eight bits and would divide by zero below.
        max_encoded_bytes = (total_bits + 7) // 8
        if max_encoded_bytes > 0:
            savings = (1 - num_encoded_bytes / max_encoded_bytes) * 100
            logger.info(
                f"Optimized frame data length is {num_encoded_bytes:,}B "
                f"instead of {max_encoded_bytes:,}B (saved {savings:.2f}%)")

    result.SegmentsOverlap = "NO"

    return result