예제 #1
0
 def test_raise_exception(self, _, args, tensor, expected_error):
     """Assert that building or applying the transform raises ``expected_error``.

     The transform is constructed inside the ``assertRaises`` context, so an
     error from the constructor also satisfies the check. The input tensor is
     cloned and, when CUDA is available, moved to the GPU before use.
     """
     with self.assertRaises(expected_error):
         transform = KeepLargestConnectedComponent(**args)
         data = tensor.clone()
         if torch.cuda.is_available():
             data = data.cuda()
         _ = transform(data)
예제 #2
0
 def test_correct_results(self, _, args, tensor, expected):
     """Check that the transform output is close to ``expected``.

     The input is cloned so the transform cannot modify the caller's tensor.
     When CUDA is available, both the input and the expectation are moved to
     the GPU before comparing.
     """
     transform = KeepLargestConnectedComponent(**args)
     data = tensor.clone()
     target = expected
     if torch.cuda.is_available():
         data = data.cuda()
         target = target.cuda()
     output = transform(data)
     assert torch.allclose(output, target)
예제 #3
0
 def test_raise_exception(self, _, args, input_image, expected_error):
     """Assert that building or applying the transform raises ``expected_error``.

     ``input_image`` may be a ``torch.Tensor`` or another array type
     (presumably a NumPy array -- confirm against the parameterized cases).
     The transform is constructed inside the ``assertRaises`` context, so an
     error raised by the constructor also satisfies the check.
     """
     with self.assertRaises(expected_error):
         converter = KeepLargestConnectedComponent(**args)
         # ``clone`` (defined elsewhere) is assumed to copy either array type.
         # Do not chain a second ``.clone()``: NumPy arrays have no such
         # method, and the resulting AttributeError would mask the exception
         # this test is meant to check.
         data = clone(input_image)
         if isinstance(input_image, torch.Tensor) and torch.cuda.is_available():
             data = data.cuda()
         _ = converter(data)
예제 #4
0
    def test_correct_results(self, _, args, input_image, expected):
        """Apply the transform to a copy of ``input_image`` and compare it with ``expected``.

        Torch tensors are moved to the GPU when CUDA is available. Other input
        types are passed through on their original device.
        """
        converter = KeepLargestConnectedComponent(**args)
        data = clone(input_image)
        on_gpu = torch.cuda.is_available() and isinstance(input_image, torch.Tensor)
        if on_gpu:
            data = data.cuda()
        assert_allclose(converter(data), expected)
    def test_correct_results_before_after_onehot(self, _, args, input_image,
                                                 expected):
        """Check that one-hot and label-map inputs give consistent results.

        From torch==1.7, ``torch.argmax`` returns the index of the first
        maximal value when there are ties. Earlier versions returned the index
        of the last one. This test makes use of that behaviour to convert
        one-hot labels back to a label map directly and check that the result
        stays the same.

        ``expected`` is not used here. Only the results for the two input
        representations are compared against each other.
        """
        converter = KeepLargestConnectedComponent(**args)
        result = converter(deepcopy(input_image))

        # Copy before toggling so the caller's dict (presumably a shared
        # parameterized test case) is not mutated for later tests.
        args = dict(args)
        if "is_onehot" in args:
            args["is_onehot"] = not args["is_onehot"]

        if input_image.shape[0] == 1:
            # Single channel, i.e. not one-hot: one-hot it, run the transform,
            # collapse back with argmax, and check the result is unchanged.
            img = to_onehot(input_image)
            result2 = KeepLargestConnectedComponent(**args)(img)
            result2 = result2.argmax(0)[None]
            assert_allclose(result, result2)
        else:
            # Already one-hot: collapse to a label map first, run the
            # transform, and compare against the collapsed original result.
            img = input_image.argmax(0)[None]
            result2 = KeepLargestConnectedComponent(**args)(img)
            assert_allclose(result.argmax(0)[None], result2)
def predict(file_name, model_path='', _params=params, output_name=None):
    """Segment an image with a saved model and write the segmentation to disk.

    The image is preprocessed with ``get_test_transforms`` and run through the
    model. The output is reduced with ``argmax`` over dim 1, and only the
    largest connected component of label 1 is kept. The result is then
    cropped to the canonical image shape and reoriented to match the input
    image.

    Args:
        file_name: Path of the input image (read with ``nib.load``).
        model_path: Path of a whole pickled model (read with ``torch.load``).
        _params: Settings dict; its ``'image_shape'`` entry is passed to
            ``get_test_transforms``.
        output_name: Path where the segmentation is written. Required.

    Raises:
        ValueError: If ``output_name`` is not given.
    """
    # Fail fast: without a destination the expensive inference below would
    # run only for ``nib.save`` to fail at the very end.
    if output_name is None:
        raise ValueError('output_name must be provided')

    print(f'Segmenting {file_name} ...')
    start = time.perf_counter()

    # Create test sample as tensor batch (unsqueeze adds a batch dimension).
    test_transforms = get_test_transforms(_params['image_shape'])
    test_file = [{"image": file_name}]
    test_batch_image = test_transforms(test_file)[0]["image"].unsqueeze(0)

    # Load model and run inference on whichever device is available.
    # https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-across-devices
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # model files. Recent torch releases default to weights_only=True, which
    # refuses whole pickled models; confirm against the installed version.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load(model_path, map_location=device)
    model.eval()
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        seg = model(test_batch_image.to(device))

    # Postprocessing: argmax over dim 1 (presumably the class channel), then
    # keep the largest connected component of label 1.
    seg = torch.argmax(seg, dim=1, keepdim=True).cpu()
    keeplargest = KeepLargestConnectedComponent(applied_labels=1)
    seg = keeplargest(seg)[0]

    # Resize output to original image size. The crop target is the canonical
    # image shape because the segmentation is presumably produced in
    # canonical orientation by the test transforms.
    img = nib.load(file_name)
    img_canon = nib.as_closest_canonical(img)
    crop = CenterSpatialCrop(img_canon.shape)
    seg = crop(seg)[0]

    # Save output seg in canonical orientation. This file is an intermediate:
    # it is read back below and then overwritten.
    seg1 = nib.Nifti1Image(seg.numpy(), img_canon.affine, img_canon.header)
    nib.save(seg1, output_name)

    # Change output seg orientation to original image orientation.
    seg_file = [{"image": output_name}]
    seg_transforms = get_seg_transforms(
        end_seg_axcodes=nib.aff2axcodes(img.affine))
    seg1 = seg_transforms(seg_file)

    # Save output seg with same orientation as original image orientation.
    seg1 = nib.Nifti1Image(seg1[0]["image"][0], img.affine, img.header)
    nib.save(seg1, output_name)

    print(f'Segmentation saved to {output_name}')
    end = time.perf_counter()
    print('√ (time taken: ', round(end - start, ndigits=4), 'seconds)')
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_inputs, roi_size,
                                                       sw_batch_size, model)
                #val_outputs = post_pred(val_outputs)
                #val_labels = post_label(val_labels)
                val_outputs = [
                    post_pred(i) for i in decollate_batch(val_outputs)
                ]
                val_labels = [
                    post_label(i) for i in decollate_batch(val_labels)
                ]
                largest = KeepLargestConnectedComponent(applied_labels=[1])
                #value = compute_meandice(
                #    y_pred=val_outputs,
                #    y=val_labels,
                #    include_background=False,
                #)
                value = dice_metric(y_pred=val_outputs, y=val_labels)
                metric_count += len(value[0])
                metric_sum += value[0].sum().item()

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()
            #metric = metric_sum / metric_count
            metric_values.append(metric)
 def test_correct_results(self, _, args, input_image, expected):
     """Apply the transform to ``input_image`` and compare it with ``expected``.

     ``type_test=False`` presumably relaxes the container/type check in
     ``assert_allclose`` so only values are compared -- confirm in its source.
     """
     output = KeepLargestConnectedComponent(**args)(input_image)
     assert_allclose(output, expected, type_test=False)