def __init__(self, **kwargs: Any) -> None:
     """
     Configures a small "Basic"-architecture test model that trains for two epochs with mixed precision.
     A step learning-rate schedule with a pronounced decay is used so that it is easy to see whether the
     schedule is applied correctly in run recovery.
     :param kwargs: Additional settings, applied via add_and_validate after the base constructor has run.
     """
     class_count = len(fg_classes)
     # Test with multiple channels, even though the "heart" channel is clearly nonsense as an image input.
     channels = ["ct", "heart"]
     weights = equally_weighted_classes(fg_classes)
     super().__init__(
          should_validate=False,
          architecture="Basic",
          feature_channels=[2] * 8,
          crop_size=(64, 64, 64),
          image_channels=channels,
          ground_truth_ids=fg_classes,
          ground_truth_ids_display_names=fg_classes,
          colours=[(255, 255, 255)] * class_count,
          fill_holes=[False] * class_count,
          roi_interpreted_types=["ORGAN"] * class_count,
          mask_id="heart",
          norm_method=PhotometricNormalizationMethod.CtWindow,
          level=50,
          window=200,
          class_weights=weights,
          num_dataload_workers=1,
          train_batch_size=8,
          start_epoch=0,
          num_epochs=2,
          recovery_checkpoint_save_interval=1,
          use_mixed_precision=True,
          azure_dataset_id=AZURE_DATASET_ID,
          # Step schedule: the learning rate is multiplied by 0.9 at every scheduler step (step size 1),
          # which makes its application during run recovery easy to spot.
          l_rate=1e-4,
          l_rate_scheduler=LRSchedulerType.Step,
          l_rate_step_step_size=1,
          l_rate_step_gamma=0.9,
          # Necessary to avoid https://github.com/pytorch/pytorch/issues/45324
          max_num_gpus=2,
     )
     self.add_and_validate(kwargs)
Exemplo n.º 2
0
    def __init__(self, **kwargs: Any) -> None:
        """
        Configures a single-epoch UNet3D model for five structures (spinal cord, right and left lung, heart,
        esophagus) on the "ct" channel, with monitoring switched off.
        :param kwargs: Additional settings, applied via add_and_validate after the base constructor has run.
        """
        fg_classes = ["spinalcord", "lung_r", "lung_l", "heart", "esophagus"]
        fg_display_names = ["SpinalCord", "Lung_R", "Lung_L", "Heart", "Esophagus"]
        structure_count = len(fg_classes)
        display_colours = generate_random_colours_list(RANDOM_COLOUR_GENERATOR, structure_count)

        super().__init__(
            should_validate=False,
            # UNet3D is chosen only because it does not shrink patches in the forward pass.
            architecture=ModelArchitectureConfig.UNet3D,
            azure_dataset_id=AZURE_DATASET_ID,
            crop_size=(64, 224, 224),
            num_dataload_workers=1,
            # Monitoring is off so that VS Code remote debugging can be used.
            monitoring_interval_seconds=0,
            image_channels=["ct"],
            ground_truth_ids=fg_classes,
            ground_truth_ids_display_names=fg_display_names,
            colours=display_colours,
            fill_holes=[False] * structure_count,
            roi_interpreted_types=["ORGAN"] * structure_count,
            inference_batch_size=1,
            class_weights=equally_weighted_classes(fg_classes, background_weight=0.02),
            feature_channels=[1],
            start_epoch=0,
            num_epochs=1,
            # Necessary to avoid https://github.com/pytorch/pytorch/issues/45324
            max_num_gpus=1,
        )
        self.add_and_validate(kwargs)
Exemplo n.º 3
0
def test_equally_weighted_classes(num_fg_classes: int, background_weight: Optional[float],
                                  expected: List[float]) -> None:
    """
    Checks that equally_weighted_classes returns a list with one weight per foreground class plus one extra
    entry, that the weights sum to 1, and that they match the expected values.
    """
    weights = equally_weighted_classes([""] * num_fg_classes, background_weight)
    assert isinstance(weights, list)
    assert len(weights) == 1 + num_fg_classes
    assert sum(weights) == pytest.approx(1.0)
    assert weights == pytest.approx(expected)
Exemplo n.º 4
0
 def __init__(self, **kwargs: Any) -> None:
     """
     Configures a 120-epoch UNet3D model for seven structures on the "ct" channel, with display names
     prefixed by "zz_" and posterior smoothing applied.
     :param kwargs: Additional settings, applied via add_and_validate after the base constructor has run.
     """
     fg_classes = [
         "external", "femur_r", "femur_l", "rectum", "prostate", "bladder",
         "seminalvesicles"
     ]
     fg_display_names = [
         "External", "Femur_R", "Femur_L", "Rectum", "Prostate", "Bladder",
         "SeminalVesicles"
     ]
     colors = [(255, 0, 0)] * len(fg_display_names)
     # Hole filling is enabled for every structure except "bladder" (index 5).
     fill_holes = [True, True, True, True, True, False, True]
     weights = equally_weighted_classes(fg_classes, background_weight=0.02)
     prefixed_names = [f"zz_{display}" for display in fg_display_names]
     # Every class except "seminalvesicles" is listed for largest-connected-component post-processing.
     lcc_classes = [cls for cls in fg_classes if cls != "seminalvesicles"]

     super().__init__(
         should_validate=False,
         adam_betas=(0.9, 0.999),
         architecture="UNet3D",
         class_weights=weights,
         crop_size=(64, 224, 224),
         feature_channels=[32],
         ground_truth_ids=fg_classes,
         ground_truth_ids_display_names=prefixed_names,
         colours=colors,
         fill_holes=fill_holes,
         image_channels=["ct"],
         inference_batch_size=1,
         inference_stride_size=(64, 256, 256),
         kernel_size=3,
         l_rate=1e-3,
         min_l_rate=1e-5,
         l_rate_polynomial_gamma=0.9,
         largest_connected_component_foreground_classes=lcc_classes,
         level=50,
         momentum=0.9,
         monitoring_interval_seconds=0,
         norm_method=PhotometricNormalizationMethod.CtWindow,
         num_dataload_workers=8,
         num_epochs=120,
         opt_eps=1e-4,
         optimizer_type=OptimizerType.Adam,
         save_step_epochs=20,
         start_epoch=0,
         test_crop_size=(128, 512, 512),
         test_diff_epochs=1,
         test_start_epoch=120,
         test_step_epochs=1,
         train_batch_size=8,
         use_mixed_precision=True,
         use_model_parallel=True,
         weight_decay=1e-4,
         window=600,
         posterior_smoothing_mm=(2.0, 2.0, 3.0),
         save_start_epoch=100,
     )
     self.add_and_validate(kwargs)
Exemplo n.º 5
0
 def __init__(self, **kwargs: Any) -> None:
     """
     Configures a 140-epoch UNet3D model for five structures (spinal cord, right and left lung, heart,
     esophagus) on the "ct" channel, trained with a mixture of focal and soft Dice losses.
     :param kwargs: Additional settings, applied via add_and_validate after the base constructor has run.
     """
     fg_classes = ["spinalcord", "lung_r", "lung_l", "heart", "esophagus"]
     fg_display_names = ["SpinalCord", "Lung_R", "Lung_L", "Heart", "Esophagus"]
     n_classes = len(fg_classes)
     # Only the lungs and the heart are listed for largest-connected-component post-processing.
     lcc_classes = ["lung_r", "lung_l", "heart"]
     weights = equally_weighted_classes(fg_classes, background_weight=0.02)
     loss_parts = [
         MixtureLossComponent(0.5, SegmentationLoss.Focal, 0.2),
         MixtureLossComponent(0.5, SegmentationLoss.SoftDice, 0.1),
     ]

     super().__init__(
         architecture="UNet3D",
         feature_channels=[32],
         kernel_size=3,
         azure_dataset_id=AZURE_DATASET_ID,
         crop_size=(64, 224, 224),
         test_crop_size=(128, 512, 512),
         image_channels=["ct"],
         ground_truth_ids=fg_classes,
         ground_truth_ids_display_names=fg_display_names,
         colours=[(255, 255, 255)] * n_classes,
         fill_holes=[False] * n_classes,
         largest_connected_component_foreground_classes=lcc_classes,
         num_dataload_workers=8,
         norm_method=PhotometricNormalizationMethod.CtWindow,
         level=40,
         window=400,
         class_weights=weights,
         train_batch_size=8,
         inference_batch_size=1,
         inference_stride_size=(64, 256, 256),
         start_epoch=0,
         num_epochs=140,
         l_rate=1e-3,
         min_l_rate=1e-5,
         l_rate_polynomial_gamma=0.9,
         optimizer_type=OptimizerType.Adam,
         opt_eps=1e-4,
         adam_betas=(0.9, 0.999),
         momentum=0.9,
         weight_decay=1e-4,
         save_start_epoch=100,
         save_step_epochs=20,
         test_start_epoch=140,
         use_mixed_precision=True,
         use_model_parallel=True,
         monitoring_interval_seconds=0,
         test_diff_epochs=1,
         test_step_epochs=1,
         loss_type=SegmentationLoss.Mixture,
         mixture_loss_components=loss_parts,
     )
     self.add_and_validate(kwargs)
Exemplo n.º 6
0
 def __init__(self, **kwargs: Any) -> None:
     """
     Creates a new instance of the class.
     Display names, colours, fill_holes, class weights and the largest-connected-component classes may each be
     supplied through kwargs. When one is absent, a built-in default is used and a log message is written.
     :param kwargs: Additional arguments that will be passed through to the SegmentationModelBase constructor.
     """
     def take_or_default(key: Any, make_default: Any) -> Any:
         # Popping removes the key so that it is not passed twice to the base constructor.
         if key in kwargs:
             return kwargs.pop(key)
         logging.info(f'Using default {key}')
         # The default is only built when it is actually needed.
         return make_default()

     ground_truth_ids = fg_classes
     ground_truth_ids_display_names = take_or_default(
         "ground_truth_ids_display_names",
         lambda: [f"zz_{label}" for label in fg_display_names])
     colours = take_or_default(
         "colours",
         lambda: [(255, 0, 0)] * len(ground_truth_ids))
     fill_holes = take_or_default(
         "fill_holes",
         lambda: [True, True, True, True, True, False, True])
     class_weights = take_or_default(
         "class_weights",
         lambda: equally_weighted_classes(ground_truth_ids, background_weight=0.02))
     largest_connected_component_foreground_classes = take_or_default(
         "largest_connected_component_foreground_classes",
         lambda: [label for label in ground_truth_ids if label != "seminalvesicles"])

     super().__init__(
         ground_truth_ids=ground_truth_ids,
         ground_truth_ids_display_names=ground_truth_ids_display_names,
         colours=colours,
         fill_holes=fill_holes,
         class_weights=class_weights,
         largest_connected_component_foreground_classes=largest_connected_component_foreground_classes,
         **kwargs)
    def __init__(self, **kwargs: Any) -> None:
        """
        Configures a tiny two-epoch UNet3D model on the local dataset at full_ml_test_data_path(). It uses
        two input channels and two foreground regions, and does not use the GPU.
        :param kwargs: Additional settings, applied via add_and_validate after the base constructor has run.
        """
        fg_classes = ["region", "region_1"]
        region_count = len(fg_classes)
        super().__init__(
            # Data definition: where the dataset is loaded from.
            local_dataset=full_ml_test_data_path(),

            # Model definition: which model to use, plus related configuration.
            architecture="UNet3D",
            feature_channels=[4],
            crop_size=(64, 64, 64),
            image_channels=["channel1", "channel2"],
            ground_truth_ids=fg_classes,
            class_weights=equally_weighted_classes(fg_classes, background_weight=0.02),
            mask_id="mask",

            # Training and testing: batch size, number of epochs, and when checkpoints are saved
            # and tested.
            use_gpu=False,
            num_dataload_workers=0,
            train_batch_size=2,
            start_epoch=0,
            num_epochs=2,
            save_start_epoch=1,
            save_step_epochs=1,
            test_start_epoch=2,
            test_diff_epochs=1,
            test_step_epochs=1,
            use_mixed_precision=True,

            # Pre-processing: CT level/window normalization of the inputs.
            norm_method=PhotometricNormalizationMethod.CtWindow,
            level=50,
            window=200,

            # Post-processing: fill holes in the generated masks of every foreground class.
            fill_holes=[True] * region_count,

            # Output: structure display names and colours.
            ground_truth_ids_display_names=fg_classes,
            colours=generate_random_colours_list(Random(5), region_count),
        )
        self.add_and_validate(kwargs)
def test_visualize_patch_sampling_2d(
        test_output_dirs: TestOutputDirectories) -> None:
    """
    Checks patch sampling on a 2D image, i.e. an array whose first spatial dimension has size 1.
    The resulting heatmap and PNG thumbnail are compared against reference files in the
    "patch_sampling" test data folder.
    :param test_output_dirs: Supplies the root folder into which the sampling results are written.
    """
    set_random_seed(0)
    spatial_shape = (1, 20, 30)
    weights = equally_weighted_classes(["fg"])
    config = SegmentationModelBase(should_validate=False,
                                   crop_size=(1, 5, 10),
                                   class_weights=weights)
    image = np.random.rand(1, *spatial_shape).astype(np.float32) * 1000
    mask = np.ones(spatial_shape)
    # Label channel 1 holds a rectangular foreground region; channel 0 is its complement.
    labels = np.zeros((len(weights),) + spatial_shape)
    labels[1, 0, 8:12, 5:25] = 1
    labels[0] = 1 - labels[1]
    results_dir = Path(test_output_dirs.root_dir)
    # No image header is used for this 2D sample.
    patient = PatientMetadata(patient_id='123', image_header=None)
    sample = Sample(image=image, mask=mask, labels=labels, metadata=patient)
    heatmap = visualize_random_crops(sample, config, output_folder=results_dir)

    reference_dir = full_ml_test_data_path("patch_sampling")
    reference_heatmap = reference_dir / "sampling_2d.npy"
    # To update the stored results, uncomment this line:
    # np.save(str(reference_heatmap), heatmap)
    assert np.allclose(heatmap, np.load(str(reference_heatmap))), "Patch sampling created a different heatmap."
    # For a 2D input, exactly one PNG and no NIfTI files are expected.
    assert len(list(results_dir.rglob("*.nii.gz"))) == 0
    assert len(list(results_dir.rglob("*.png"))) == 1
    thumbnail = results_dir / "123_sampled_patches.png"
    assert_file_exists(thumbnail)
    reference_thumbnail = reference_dir / "sampling_2d.png"
    # To update the stored results, uncomment this line:
    # reference_thumbnail.write_bytes(thumbnail.read_bytes())
    if is_running_on_azure():
        # When running on the Azure build agents, it appears that the bounding box of the images
        # is slightly different than on local runs, even with equal dpi settings.
        # It says: Image sizes don't match: actual (685, 469), expected (618, 424)
        # Not able to figure out how to make the run results consistent, hence disable in cloud runs.
        return
    assert_binary_files_match(thumbnail, reference_thumbnail)
    def __init__(self, **kwargs: Any) -> None:
        """
        Configures a 120-epoch UNet3D model for seven structures on the "ct" channel.
        :param kwargs: Additional settings, applied via add_and_validate after the base constructor has run.
        """
        fg_classes = ["external", "femur_r", "femur_l", "rectum", "prostate", "bladder", "seminalvesicles"]
        # Per-structure settings are sized from the same local list that defines ground_truth_ids.
        # This keeps them equal in length, and avoids reading an attribute on self before the base
        # constructor has run.
        num_classes = len(fg_classes)
        super().__init__(
            should_validate=False,
            architecture="UNet3D",
            feature_channels=[32],
            kernel_size=3,
            crop_size=(64, 224, 224),
            test_crop_size=(128, 512, 512),
            image_channels=["ct"],
            ground_truth_ids=fg_classes,
            largest_connected_component_foreground_classes=["external", "femur_r", "femur_l", "rectum", "prostate",
                                                            "bladder"],

            colours=[(255, 255, 255)] * num_classes,
            fill_holes=[False] * num_classes,
            ground_truth_ids_display_names=fg_classes,
            num_dataload_workers=8,
            norm_method=PhotometricNormalizationMethod.CtWindow,
            level=50,
            window=600,
            class_weights=equally_weighted_classes(fg_classes, background_weight=0.02),
            train_batch_size=8,
            inference_batch_size=1,
            inference_stride_size=(64, 256, 256),
            start_epoch=0,
            num_epochs=120,
            l_rate=1e-3,
            min_l_rate=1e-5,
            l_rate_polynomial_gamma=0.9,
            optimizer_type=OptimizerType.Adam,
            opt_eps=1e-4,
            adam_betas=(0.9, 0.999),
            momentum=0.9,
            weight_decay=1e-4,
            save_start_epoch=20,
            save_step_epochs=20,
            test_start_epoch=120,
            use_mixed_precision=True,
            use_model_parallel=True,
            monitoring_interval_seconds=0,
            test_diff_epochs=1,
            test_step_epochs=1
        )
        self.add_and_validate(kwargs)
Exemplo n.º 10
0
 def __init__(self, **kwargs: Any) -> None:
     """
     Configures a 200-epoch UNet3D model that segments the single structure "tumour_mass" from the "mr"
     channel, using MRI window normalization and no mask.
     :param kwargs: Additional settings, applied via add_and_validate after the base constructor has run.
     """
     fg_classes = ["tumour_mass"]
     class_count = len(fg_classes)
     # This test crop size encloses all images in the dataset.
     whole_image_crop = (256, 320, 320)
     super().__init__(
         should_validate=False,
         architecture=ModelArchitectureConfig.UNet3D,
         feature_channels=[32],
         crop_size=(64, 192, 160),
         kernel_size=3,
         test_crop_size=whole_image_crop,
         inference_stride_size=(128, 160, 160),
         inference_batch_size=1,
         image_channels=["mr"],
         ground_truth_ids=fg_classes,
         ground_truth_ids_display_names=fg_classes,
         colours=[(255, 255, 255)] * class_count,
         fill_holes=[False] * class_count,
         num_dataload_workers=8,
         mask_id=None,
         norm_method=PhotometricNormalizationMethod.MriWindow,
         trim_percentiles=(1, 99),
         sharpen=2.5,
         tail=[1.0],
         class_weights=equally_weighted_classes(fg_classes),
         train_batch_size=8,
         start_epoch=0,
         num_epochs=200,
         l_rate=1e-3,
         l_rate_polynomial_gamma=0.9,
         optimizer_type=OptimizerType.Adam,
         opt_eps=1e-4,
         adam_betas=(0.9, 0.999),
         momentum=0.9,
         weight_decay=1e-4,
         save_start_epoch=50,
         save_step_epochs=10,
         test_start_epoch=50,
         test_diff_epochs=20,
         test_step_epochs=10,
         use_mixed_precision=True,
         use_model_parallel=True,
     )
     self.add_and_validate(kwargs)
Exemplo n.º 11
0
def test_equally_weighted_classes_fails(
        num_fg_clases: int, background_weight: Optional[float]) -> None:
    """
    Checks that equally_weighted_classes raises a ValueError for the given number of foreground classes
    and background weight.
    """
    # NOTE(review): the parameter spelling "num_fg_clases" is presumably bound by a pytest parametrization
    # outside this view; rename both together if it is ever corrected.
    fg_names = [""] * num_fg_clases
    with pytest.raises(ValueError):
        equally_weighted_classes(fg_names, background_weight)
 def __init__(self, num_structures: int = 0, **kwargs: Any) -> None:
     """
     Configures a UNet3D model that predicts the first num_structures entries of STRUCTURE_LIST.
     Slice-exclusion and summed-probability rules are added for selected pairs of structures when both
     members of the pair are predicted.
     :param num_structures: number of structures from STRUCTURE_LIST to predict. Values <= 0 or larger than
     len(STRUCTURE_LIST) select all structures (the default).
     :param kwargs: other args from subclass, applied via add_and_validate after the base constructor has run.
     """
     # Number of training epochs
     num_epochs = 120
     # Number of structures to predict; if positive and at most the length of STRUCTURE_LIST, the relevant
     # prefix of STRUCTURE_LIST will be predicted.
     if num_structures <= 0 or num_structures > len(STRUCTURE_LIST):
         num_structures = len(STRUCTURE_LIST)
     ground_truth_ids = STRUCTURE_LIST[:num_structures]
     colours = COLOURS[:num_structures]
     fill_holes = FILL_HOLES[:num_structures]
     ground_truth_ids_display_names = [f"zz_{x}" for x in ground_truth_ids]
     # The amount of GPU memory required increases with both the number of structures and the
     # number of feature channels. The following is a sensible default to avoid out-of-memory,
     # but you can override it by passing in another (singleton list) value for feature_channels
     # from a subclass.
     num_feature_channels = 32 if num_structures <= 20 else 26
     # A single-structure model gets a larger background weight.
     bg_weight = 0.02 if len(ground_truth_ids) > 1 else 0.25
     # In case of vertical overlap between brainstem and spinal_cord, we separate them
     # by converting brainstem voxels to cord, as the latter is clinically more sensitive.
     # We do the same to separate SPC and MPC; in this case, the direction of change is unimportant,
     # so we choose SPC-to-MPC arbitrarily. A further rule is added for optic_chiasm and pituitary_gland
     # (with the third SliceExclusionRule argument set to True). Each pair also gets a summed-probability
     # rule when "external" is predicted.
     slice_exclusion_rules = []
     summed_probability_rules = []
     if "brainstem" in ground_truth_ids and "spinal_cord" in ground_truth_ids:
         slice_exclusion_rules.append(
             SliceExclusionRule("brainstem", "spinal_cord", False))
         if "external" in ground_truth_ids:
             summed_probability_rules.append(
                 SummedProbabilityRule("spinal_cord", "brainstem",
                                       "external"))
     if "spc_muscle" in ground_truth_ids and "mpc_muscle" in ground_truth_ids:
         slice_exclusion_rules.append(
             SliceExclusionRule("spc_muscle", "mpc_muscle", False))
         if "external" in ground_truth_ids:
             summed_probability_rules.append(
                 SummedProbabilityRule("mpc_muscle", "spc_muscle",
                                       "external"))
     if "optic_chiasm" in ground_truth_ids and "pituitary_gland" in ground_truth_ids:
         slice_exclusion_rules.append(
             SliceExclusionRule("optic_chiasm", "pituitary_gland", True))
         if "external" in ground_truth_ids:
             summed_probability_rules.append(
                 SummedProbabilityRule("optic_chiasm", "pituitary_gland",
                                       "external"))
     super().__init__(
         should_validate=False,  # we'll validate after kwargs are added
         num_gpus=4,
         num_epochs=num_epochs,
         save_start_epoch=num_epochs,
         # "save_step_epochs" is the parameter name used by every other configuration; the previous
         # spelling "save_step_epoch" did not match it.
         save_step_epochs=num_epochs,
         architecture="UNet3D",
         kernel_size=3,
         train_batch_size=4,
         inference_batch_size=1,
         feature_channels=[num_feature_channels],
         crop_size=(96, 288, 288),
         test_crop_size=(144, 512, 512),
         inference_stride_size=(72, 256, 256),
         image_channels=["ct"],
         norm_method=PhotometricNormalizationMethod.CtWindow,
         level=50,
         window=600,
         start_epoch=0,
         l_rate=1e-3,
         min_l_rate=1e-5,
         l_rate_polynomial_gamma=0.9,
         optimizer_type=OptimizerType.Adam,
         opt_eps=1e-4,
         adam_betas=(0.9, 0.999),
         momentum=0.9,
         test_diff_epochs=1,
         test_step_epochs=1,
         use_mixed_precision=True,
         use_model_parallel=True,
         monitoring_interval_seconds=0,
         num_dataload_workers=4,
         loss_type=SegmentationLoss.Mixture,
         mixture_loss_components=[
             MixtureLossComponent(0.5, SegmentationLoss.Focal, 0.2),
             MixtureLossComponent(0.5, SegmentationLoss.SoftDice, 0.1)
         ],
         ground_truth_ids=ground_truth_ids,
         ground_truth_ids_display_names=ground_truth_ids_display_names,
         largest_connected_component_foreground_classes=ground_truth_ids,
         colours=colours,
         fill_holes=fill_holes,
         class_weights=equally_weighted_classes(
             ground_truth_ids, background_weight=bg_weight),
         slice_exclusion_rules=slice_exclusion_rules,
         summed_probability_rules=summed_probability_rules,
     )
     self.add_and_validate(kwargs)
def test_visualize_patch_sampling(test_output_dirs: TestOutputDirectories,
                                  labels_to_boundary: bool) -> None:
    """
    Tests if patch sampling and producing diagnostic images works as expected.
    The heatmap, the sampled-patches NIfTI image, and (when labels_to_boundary is set and not running on
    Azure) the PNG thumbnails are compared against stored reference results.
    :param test_output_dirs: Supplies the root folder into which the sampling results are written.
    :param labels_to_boundary: If true, the ground truth labels are placed close to the image boundary, so that
    crops have to be adjusted inwards. If false, ground truth labels are all far from the image boundaries.
    """
    set_random_seed(0)
    shape = (10, 30, 30)
    foreground_classes = ["fg"]
    class_weights = equally_weighted_classes(foreground_classes)
    config = SegmentationModelBase(should_validate=False,
                                   crop_size=(2, 10, 10),
                                   class_weights=class_weights)
    image = np.random.rand(1, *shape).astype(np.float32) * 1000
    mask = np.ones(shape)
    labels = np.zeros((len(class_weights), ) + shape)
    if labels_to_boundary:
        # Generate foreground labels in such a way that a patch centered around a foreground pixel would
        # reach outside of the image.
        labels[1, 4:8, 3:27, 3:27] = 1
    else:
        labels[1, 4:8, 15:18, 15:18] = 1
    labels[0] = 1 - labels[1]
    output_folder = Path(test_output_dirs.root_dir)
    image_header = get_unit_image_header()
    sample = Sample(image=image,
                    mask=mask,
                    labels=labels,
                    metadata=PatientMetadata(patient_id='123',
                                             image_header=image_header))
    expected_folder = full_ml_test_data_path("patch_sampling")
    heatmap = visualize_random_crops(sample,
                                     config,
                                     output_folder=output_folder)
    expected_heatmap = expected_folder / ("sampled_to_boundary.npy"
                                          if labels_to_boundary else
                                          "sampled_center.npy")
    # To update the stored results, uncomment this line:
    # np.save(str(expected_heatmap), heatmap)
    assert np.allclose(heatmap, np.load(
        str(expected_heatmap))), "Patch sampling created a different heatmap."
    f1 = output_folder / "123_ct.nii.gz"
    assert_file_exists(f1)
    f2 = output_folder / "123_sampled_patches.nii.gz"
    assert_file_exists(f2)
    thumbnails = [
        "123_sampled_patches_dim0.png",
        "123_sampled_patches_dim1.png",
        "123_sampled_patches_dim2.png",
    ]
    for f in thumbnails:
        assert_file_exists(output_folder / f)

    expected = expected_folder / ("sampled_to_boundary.nii.gz"
                                  if labels_to_boundary else
                                  "sampled_center.nii.gz")
    # To update test results:
    # shutil.copy(str(f2), str(expected))
    expected_image = io_util.load_nifti_image(expected)
    actual_image = io_util.load_nifti_image(f2)
    # np.allclose only returns a bool; it must be asserted, otherwise the comparison has no effect.
    assert np.allclose(expected_image.image, actual_image.image), \
        "Patch sampling created a different sampled-patches image."
    if labels_to_boundary:
        for f in thumbnails:
            # Uncomment this line to update test results
            # (expected_folder / f).write_bytes((output_folder / f).read_bytes())
            if not is_running_on_azure():
                # When running on the Azure build agents, it appears that the bounding box of the images
                # is slightly different than on local runs, even with equal dpi settings.
                # Not able to figure out how to make the run results consistent, hence disable in cloud runs.
                assert_binary_files_match(output_folder / f,
                                          expected_folder / f)
Exemplo n.º 14
0
 def __init__(self,
              ground_truth_ids: List[str],
              ground_truth_ids_display_names: Optional[List[str]] = None,
              colours: Optional[List[TupleInt3]] = None,
              fill_holes: Optional[List[bool]] = None,
              class_weights: Optional[List[float]] = None,
              largest_connected_component_foreground_classes: Optional[
                  List[str]] = None,
              **kwargs: Any) -> None:
     """
     Creates a new instance of the class.
     Optional arguments left as None are replaced by the defaults described below. A value that is passed
     explicitly, including an empty list, is used as given.
     :param ground_truth_ids: List of ground truth ids.
     :param ground_truth_ids_display_names: Optional list of ground truth id display names. If
     present then must be of the same length as ground_truth_ids. Defaults to each id prefixed with "zz_".
     :param colours: Optional list of colours. If
     present then must be of the same length as ground_truth_ids. Defaults to (255, 0, 0) for every id.
     :param fill_holes: Optional list of fill hole flags. If
     present then must be of the same length as ground_truth_ids. Defaults to True for every id.
     :param class_weights: Optional list of class weights. If
     present then must be of the same length as ground_truth_ids + 1. Defaults to
     equally_weighted_classes(ground_truth_ids, background_weight=0.02).
     :param largest_connected_component_foreground_classes: Optional list of classes to use for
     largest-connected-component post-processing (per the parameter name; confirm the exact semantics in
     SegmentationModelBase). Defaults to all ground_truth_ids.
     :param kwargs: Additional arguments, applied via add_and_validate after the base constructor has run.
     """
     # "is None" instead of "or": an explicitly passed empty list (e.g. no largest-connected-component
     # classes) must not be silently replaced by the default.
     if ground_truth_ids_display_names is None:
         ground_truth_ids_display_names = [f"zz_{name}" for name in ground_truth_ids]
     if colours is None:
         colours = [(255, 0, 0)] * len(ground_truth_ids)
     if fill_holes is None:
         fill_holes = [True] * len(ground_truth_ids)
     if class_weights is None:
         class_weights = equally_weighted_classes(ground_truth_ids, background_weight=0.02)
     if largest_connected_component_foreground_classes is None:
         largest_connected_component_foreground_classes = ground_truth_ids
     super().__init__(
         should_validate=False,
         adam_betas=(0.9, 0.999),
         architecture="UNet3D",
         class_weights=class_weights,
         crop_size=(64, 224, 224),
         feature_channels=[32],
         ground_truth_ids=ground_truth_ids,
         ground_truth_ids_display_names=ground_truth_ids_display_names,
         colours=colours,
         fill_holes=fill_holes,
         image_channels=["ct"],
         inference_batch_size=1,
         inference_stride_size=(64, 256, 256),
         kernel_size=3,
         l_rate=1e-3,
         min_l_rate=1e-5,
         l_rate_polynomial_gamma=0.9,
         largest_connected_component_foreground_classes=largest_connected_component_foreground_classes,
         level=50,
         momentum=0.9,
         monitoring_interval_seconds=0,
         norm_method=PhotometricNormalizationMethod.CtWindow,
         num_dataload_workers=8,
         num_epochs=120,
         opt_eps=1e-4,
         optimizer_type=OptimizerType.Adam,
         save_step_epochs=20,
         start_epoch=0,
         test_crop_size=(128, 512, 512),
         test_diff_epochs=1,
         test_start_epoch=120,
         test_step_epochs=1,
         train_batch_size=8,
         use_mixed_precision=True,
         use_model_parallel=True,
         weight_decay=1e-4,
         window=600,
         posterior_smoothing_mm=(2.0, 2.0, 3.0),
         save_start_epoch=100,
     )
     self.add_and_validate(kwargs)
 def __init__(self,
              ground_truth_ids: List[str],
              ground_truth_ids_display_names: Optional[List[str]] = None,
              colours: Optional[List[TupleInt3]] = None,
              fill_holes: Optional[List[bool]] = None,
              class_weights: Optional[List[float]] = None,
              slice_exclusion_rules: Optional[List[SliceExclusionRule]] = None,
              summed_probability_rules: Optional[List[SummedProbabilityRule]] = None,
              num_feature_channels: Optional[int] = None,
              **kwargs: Any) -> None:
     """
     Creates a new instance of the class.
     Optional arguments left as None are replaced by the defaults described below. A value that is passed
     explicitly, including an empty list, is used as given.
     :param ground_truth_ids: List of ground truth ids.
     :param ground_truth_ids_display_names: Optional list of ground truth id display names. If
     present then must be of the same length as ground_truth_ids. Defaults to each id prefixed with "zz_".
     :param colours: Optional list of colours. If
     present then must be of the same length as ground_truth_ids. Defaults to colours drawn from
     RANDOM_COLOUR_GENERATOR.
     :param fill_holes: Optional list of fill hole flags. If
     present then must be of the same length as ground_truth_ids. Defaults to True for every id.
     :param class_weights: Optional list of class weights. If
     present then must be of the same length as ground_truth_ids + 1. Defaults to
     equally_weighted_classes with background weight 0.02 (or 0.25 when there is only one structure).
     :param slice_exclusion_rules: Optional list of SliceExclusionRules. Defaults to no rules.
     :param summed_probability_rules: Optional list of SummedProbabilityRule. Defaults to no rules.
     :param num_feature_channels: Optional number of feature channels. Defaults to 32 for up to 20
     structures, 26 otherwise.
     :param kwargs: Additional arguments, applied via add_and_validate after the base constructor has run.
     """
     # Number of training epochs
     num_epochs = 120
     num_structures = len(ground_truth_ids)
     if colours is None:
         colours = generate_random_colours_list(RANDOM_COLOUR_GENERATOR, num_structures)
     if fill_holes is None:
         fill_holes = [True] * num_structures
     if ground_truth_ids_display_names is None:
         ground_truth_ids_display_names = [f"zz_{x}" for x in ground_truth_ids]
     # The amount of GPU memory required increases with both the number of structures and the
     # number of feature channels. The following is a sensible default to avoid out-of-memory;
     # override it via num_feature_channels.
     if num_feature_channels is None:
         num_feature_channels = 32 if num_structures <= 20 else 26
     if class_weights is None:
         bg_weight = 0.02 if num_structures > 1 else 0.25
         class_weights = equally_weighted_classes(ground_truth_ids, background_weight=bg_weight)
     # No post-processing rules are derived here from the structure names; callers that need to
     # separate overlapping structures must pass the rules in explicitly.
     if slice_exclusion_rules is None:
         slice_exclusion_rules = []
     if summed_probability_rules is None:
         summed_probability_rules = []
     super().__init__(
         should_validate=False,  # we'll validate after kwargs are added
         num_epochs=num_epochs,
         recovery_checkpoint_save_interval=10,
         architecture="UNet3D",
         kernel_size=3,
         train_batch_size=1,
         inference_batch_size=1,
         feature_channels=[num_feature_channels],
         crop_size=(96, 288, 288),
         test_crop_size=(144, 512, 512),
         inference_stride_size=(72, 256, 256),
         image_channels=["ct"],
         norm_method=PhotometricNormalizationMethod.CtWindow,
         level=50,
         window=600,
         start_epoch=0,
         l_rate=1e-3,
         min_l_rate=1e-5,
         l_rate_polynomial_gamma=0.9,
         optimizer_type=OptimizerType.Adam,
         opt_eps=1e-4,
         adam_betas=(0.9, 0.999),
         momentum=0.9,
         use_mixed_precision=True,
         use_model_parallel=True,
         monitoring_interval_seconds=0,
         num_dataload_workers=2,
         loss_type=SegmentationLoss.Mixture,
         mixture_loss_components=[MixtureLossComponent(0.5, SegmentationLoss.Focal, 0.2),
                                  MixtureLossComponent(0.5, SegmentationLoss.SoftDice, 0.1)],
         ground_truth_ids=ground_truth_ids,
         ground_truth_ids_display_names=ground_truth_ids_display_names,
         largest_connected_component_foreground_classes=ground_truth_ids,
         colours=colours,
         fill_holes=fill_holes,
         class_weights=class_weights,
         slice_exclusion_rules=slice_exclusion_rules,
         summed_probability_rules=summed_probability_rules,
     )
     self.add_and_validate(kwargs)
Exemplo n.º 16
0
    def __init__(self, num_structures: Optional[int] = None, **kwargs: Any) -> None:
        """
        Creates a new instance of the class.

        The following settings may be supplied via kwargs: ground_truth_ids_display_names, colours, fill_holes,
        num_feature_channels, class_weights, slice_exclusion_rules and summed_probability_rules. A supplied value
        (even None) is removed from kwargs and forwarded unchanged to the parent constructor. For any setting that
        is absent, a default derived from the selected structures is used and an info message is logged.

        :param num_structures: number of structures from STRUCTURE_LIST to predict (default: all structures).
            When given, it must be between 1 and len(STRUCTURE_LIST) inclusive; the first num_structures entries
            of STRUCTURE_LIST become the ground truth ids.
        :param kwargs: Additional arguments that will be passed through to the SegmentationModelBase constructor.
        :raises ValueError: If num_structures is given but is less than 1 or greater than len(STRUCTURE_LIST).
        """
        max_structures = len(STRUCTURE_LIST)
        if num_structures is None:
            logging.info(f'Setting num_structures to: {max_structures}')
            num_structures = max_structures
        elif num_structures <= 0 or num_structures > max_structures:
            # 0 is rejected as well, so the smallest valid value reported to the caller is 1.
            raise ValueError(f"num structures must be between 1 and {max_structures}")
        # Predict the prefix of STRUCTURE_LIST of the requested length.
        ground_truth_ids = STRUCTURE_LIST[:num_structures]

        def pop_or_default(key: str, make_default: Any) -> Any:
            # Remove an explicitly supplied value from kwargs so that it is not passed to the parent constructor
            # twice; otherwise log and build the default (lazily, so it is only computed when actually needed).
            if key in kwargs:
                return kwargs.pop(key)
            logging.info('Using default %s', key)
            return make_default()

        ground_truth_ids_display_names = pop_or_default(
            "ground_truth_ids_display_names", lambda: [f"zz_{x}" for x in ground_truth_ids])
        colours = pop_or_default("colours", lambda: COLOURS[:num_structures])
        fill_holes = pop_or_default("fill_holes", lambda: [True] * num_structures)
        # The amount of GPU memory required increases with both the number of structures and the
        # number of feature channels. The default below is a sensible choice to avoid out-of-memory;
        # override it by passing num_feature_channels in kwargs (e.g. from a subclass).
        num_feature_channels = pop_or_default(
            "num_feature_channels", lambda: 32 if num_structures <= 20 else 26)
        # The background class gets a much lower weight when more than one foreground structure is predicted.
        class_weights = pop_or_default(
            "class_weights",
            lambda: equally_weighted_classes(ground_truth_ids,
                                             background_weight=0.02 if len(ground_truth_ids) > 1 else 0.25))

        # In case of vertical overlap between brainstem and spinal_cord, we separate them
        # by converting brainstem voxels to cord, as the latter is clinically more sensitive.
        # We do the same to separate SPC and MPC; in this case, the direction of change is unimportant,
        # so we choose SPC-to-MPC arbitrarily. A rule is only added when both structures are predicted;
        # each tuple holds the three SliceExclusionRule constructor arguments, in order.
        exclusion_rule_args = [("brainstem", "spinal_cord", False),
                               ("spc_muscle", "mpc_muscle", False),
                               ("optic_chiasm", "pituitary_gland", True)]
        slice_exclusion_rules = pop_or_default(
            "slice_exclusion_rules",
            lambda: [SliceExclusionRule(first, second, flag)
                     for first, second, flag in exclusion_rule_args
                     if first in ground_truth_ids and second in ground_truth_ids])

        # Summed-probability rules for the same structure pairs, added only when both structures and "external"
        # are predicted. NOTE: the argument order is reversed relative to the exclusion rules for the first two
        # pairs but not for the optic_chiasm / pituitary_gland pair, matching the original rule definitions.
        summed_rule_args = [("spinal_cord", "brainstem"),
                            ("mpc_muscle", "spc_muscle"),
                            ("optic_chiasm", "pituitary_gland")]
        summed_probability_rules = pop_or_default(
            "summed_probability_rules",
            lambda: [SummedProbabilityRule(first, second, "external")
                     for first, second in summed_rule_args
                     if first in ground_truth_ids and second in ground_truth_ids and "external" in ground_truth_ids])

        super().__init__(
            ground_truth_ids=ground_truth_ids,
            ground_truth_ids_display_names=ground_truth_ids_display_names,
            colours=colours,
            fill_holes=fill_holes,
            class_weights=class_weights,
            slice_exclusion_rules=slice_exclusion_rules,
            summed_probability_rules=summed_probability_rules,
            num_feature_channels=num_feature_channels,
            **kwargs)