Ejemplo n.º 1
0
 def test_segmentation(self):
     """Smoke-test femoral cartilage segmentation on this scan type.

     Builds the model named by ``SEGMENTATION_MODEL`` for the scan's first two
     dimensions (single input channel) and runs ``scan.segment`` once. The test
     passes as long as no exception is raised; the segmentation output is not
     inspected and batch size is not varied.

     NOTE(review): an earlier docstring claimed this test checks whether batch
     size changes the output. No such comparison is performed here.
     """
     scan = self.SCAN_TYPE(dicom_path=self.dicom_dirpath)
     tissue = FemoralCartilage()
     tissue.find_weights(SEGMENTATION_WEIGHTS_FOLDER)
     dims = scan.get_dimensions()
     # 2D model input: first two scan dimensions plus one channel.
     input_shape = (dims[0], dims[1], 1)
     model = get_model(SEGMENTATION_MODEL,
                       input_shape=input_shape,
                       weights_path=tissue.weights_file_path)
     scan.segment(model, tissue)
Ejemplo n.º 2
0
def handle_segmentation(vargin, scan: ScanSequence, tissue: Tissue):
    """Construct the segmentation model requested by the parsed arguments.

    Args:
        vargin: Mapping-like container of parsed arguments. Reads the first
            entry under ``SEGMENTATION_WEIGHTS_DIR_KEY`` (weights directory),
            ``SEGMENTATION_MODEL_KEY`` (model name) and
            ``SEGMENTATION_BATCH_SIZE_KEY`` (batch size).
        scan: Scan whose first two dimensions set the model's 2D input size.
        tissue: Tissue on which ``find_weights`` is called with the weights
            directory; its ``weights_file_path`` is then handed to the model.

    Returns:
        The model returned by ``get_model`` with ``batch_size`` assigned.
    """
    weights_dir = vargin[SEGMENTATION_WEIGHTS_DIR_KEY][0]
    tissue.find_weights(weights_dir)

    # Single input channel; spatial size follows the scan's first two dims.
    scan_dims = scan.get_dimensions()
    model_input_shape = scan_dims[0], scan_dims[1], 1

    seg_model = get_model(
        vargin[SEGMENTATION_MODEL_KEY],
        input_shape=model_input_shape,
        weights_path=tissue.weights_file_path,
    )
    seg_model.batch_size = vargin[SEGMENTATION_BATCH_SIZE_KEY]
    return seg_model
Ejemplo n.º 3
0
    def test_segmentation(self):
        """Automatic segmentation should be unsupported for this scan type.

        Builds the model named by ``SEGMENTATION_MODEL`` for the scan's first
        two dimensions and checks that ``scan.segment`` raises
        ``NotImplementedError``.
        """
        scan_obj = self.SCAN_TYPE(dicom_path=self.dicom_dirpath)
        cartilage = FemoralCartilage()
        cartilage.find_weights(SEGMENTATION_WEIGHTS_FOLDER)

        scan_dims = scan_obj.get_dimensions()
        model_input_shape = (scan_dims[0], scan_dims[1], 1)
        seg_model = get_model(
            SEGMENTATION_MODEL,
            input_shape=model_input_shape,
            weights_path=cartilage.weights_file_path,
        )

        # Automatic segmentation is not implemented for this scan type yet.
        with self.assertRaises(NotImplementedError):
            scan_obj.segment(seg_model, cartilage)
Ejemplo n.º 4
0
    def test_segmentation_multiclass(self):
        """Test support for multiclass segmentation.

        Loads the scan from DICOM, builds the model named by
        ``SEGMENTATION_MODEL`` for the scan's first two dimensions and runs
        ``scan.segment`` with ``use_rss=True``. Keras state is cleared
        afterwards even if segmentation fails, so later tests are not affected.
        """
        scan = self.SCAN_TYPE.from_dicom(self.dicom_dirpath,
                                         num_workers=util.num_workers())
        tissue = FemoralCartilage()
        # The original had a stray trailing comma here, which made the call a
        # throwaway tuple expression.
        tissue.find_weights(SEGMENTATION_WEIGHTS_FOLDER)
        dims = scan.get_dimensions()
        # 2D model input: first two scan dimensions plus one channel.
        input_shape = (dims[0], dims[1], 1)
        model = get_model(SEGMENTATION_MODEL,
                          input_shape=input_shape,
                          weights_path=tissue.weights_file_path)
        try:
            scan.segment(model, tissue, use_rss=True)
        finally:
            # Drop the last reference; this is intended to trigger __del__ in
            # KerasSegModel. Then reset Keras global state. Both now run even
            # when segment() raises.
            model = None
            K.clear_session()
Ejemplo n.º 5
0
def handle_segmentation(vargin, scan: ScanSequence, tissue: Tissue):
    """Build the segmentation model requested by the parsed arguments.

    Either a built-in model (``SEGMENTATION_MODEL_KEY``) or a model config
    (``SEGMENTATION_CONFIG_KEY``) must be given. If both are set, the built-in
    model is used.

    Args:
        vargin: Mapping-like container of parsed command-line arguments.
        scan: Scan whose first two dimensions define the model's 2D input.
        tissue: A single tissue, or a sequence of tissues that must all
            resolve to the same weights file.

    Returns:
        The constructed model with ``batch_size`` set from
        ``SEGMENTATION_BATCH_SIZE_KEY``.

    Raises:
        ValueError: If neither a model nor a config is specified, if
            ``tissue`` is an empty sequence, or if the tissues resolve to
            different weights files.
    """
    if not vargin[SEGMENTATION_MODEL_KEY] and not vargin[SEGMENTATION_CONFIG_KEY]:
        raise ValueError(
            "Either `--{}` or `--{}` must be specified".format(
                SEGMENTATION_MODEL_KEY, SEGMENTATION_CONFIG_KEY
            )
        )

    segment_weights_path = vargin[SEGMENTATION_WEIGHTS_DIR_KEY][0]
    if isinstance(tissue, Sequence):
        weights = [t.find_weights(segment_weights_path) for t in tissue]
        if not weights:
            # Previously this surfaced as an opaque IndexError on weights[0].
            raise ValueError("At least one tissue must be specified for segmentation")
        # Only one model is built below, so every tissue must resolve to the
        # same weights file. This is an explicit check rather than `assert`,
        # which is stripped when Python runs with -O.
        if not all(weights_file == weights[0] for weights_file in weights):
            raise ValueError(
                "All tissues must share the same segmentation weights file, "
                "got: {}".format(weights)
            )
        weights_path = weights[0]
    else:
        weights_path = tissue.find_weights(segment_weights_path)

    # Load model
    dims = scan.get_dimensions()
    # TODO: Input shape should be determined by combination of model + scan.
    # Currently fixed in 2D plane
    input_shape = (dims[0], dims[1], 1)
    if vargin[SEGMENTATION_MODEL_KEY]:
        # Use built-in model
        model = get_model(
            vargin[SEGMENTATION_MODEL_KEY], input_shape=input_shape, weights_path=weights_path
        )
    else:
        # Use config. `weights_path` is not used on this branch; the config
        # loader receives the weights directory instead.
        model = model_from_config(
            vargin[SEGMENTATION_CONFIG_KEY],
            weights_dir=segment_weights_path,
            input_shape=input_shape,
        )
    model.batch_size = vargin[SEGMENTATION_BATCH_SIZE_KEY]

    return model