def test_train(self, input_args):
        """Fit a trainable GaussianFilter's sigma to a target made by a fixed-sigma filter.

        A non-trainable filter with sigma ``input_args["gt"]`` filters an all-ones
        tensor to produce a target. A trainable filter initialised with
        ``input_args["init"]`` is then optimised with Adam on the mean-squared error.
        Each learned sigma must match ``gt`` within 1% relative tolerance.

        ``input_args`` keys:
            gt: target sigma (scalar or one value per spatial dim), required.
            type: ``approx`` mode passed to GaussianFilter, required.
            dims: input shape, default ``(2, 3, 8)``. The first two dims are
                treated as non-spatial (``spatial_dims = ndim - 2``).
            device: ``"cuda"`` to run on a GPU when one is available.
            lr: Adam learning rate, default 0.1.
            init: initial sigma of the trainable filter, default 1.0.
        """
        input_dims = input_args.get("dims", (2, 3, 8))
        use_cuda = input_args.get("device") == "cuda" and torch.cuda.is_available()
        device = torch.device("cuda") if use_cuda else torch.device("cpu:0")

        base = torch.ones(*input_dims).to(device)
        # Force a float dtype: an integer sigma would otherwise become an int64 tensor.
        gt = torch.tensor(input_args["gt"], dtype=torch.float, requires_grad=False).to(device)
        g_type = input_args["type"]
        lr = input_args.get("lr", 0.1)
        init_sigma = input_args.get("init", 1.0)

        # static filter to generate a target
        spatial_dims = base.ndim - 2
        filtering = GaussianFilter(spatial_dims=spatial_dims,
                                   sigma=gt,
                                   approx=g_type,
                                   requires_grad=False)
        filtering.to(device)
        target = filtering(base)
        self.assertFalse(filtering.sigma[0].requires_grad)

        # build trainable; only floating-point tensors can require gradients, so an
        # integer ``init`` (e.g. 1) must not produce an int64 tensor here
        init_sigma = torch.tensor(init_sigma, dtype=torch.float).to(device)
        trainable = GaussianFilter(spatial_dims=spatial_dims,
                                   sigma=init_sigma,
                                   approx=g_type,
                                   requires_grad=True)
        trainable.to(device)
        self.assertTrue(trainable.sigma[0].requires_grad)

        # train
        optimizer = torch.optim.Adam(trainable.parameters(), lr=lr)
        for step in range(1000):
            optimizer.zero_grad()
            pred = trainable(base)
            loss = torch.square(pred - target).mean()
            loss.backward()
            # Stop as soon as the fit is essentially exact. Breaking before step()
            # keeps the sigma that achieved this loss.
            if loss.item() < 1e-7:
                break
            optimizer.step()

        # check the result: one learned sigma per entry of trainable.sigma
        for idx, learned_sigma in enumerate(trainable.sigma):
            expected = gt.item() if gt.ndim == 0 else gt[idx].item()
            np.testing.assert_allclose(
                learned_sigma.item(),
                expected,
                rtol=1e-2,
                err_msg=f"stopped at step {step}, final loss {loss.item():.3e}")
Esempio n. 2
0
class ProbNMS(Transform):
    """
    Performs probability based non-maximum suppression (NMS) on the probabilities map by
    iteratively selecting the coordinate with the highest probability and then removing it
    as well as its surrounding values. The removal range is determined by the parameter
    `box_size`. If multiple coordinates have the same highest probability, only one of them
    will be selected.

    Args:
        spatial_dims: number of spatial dimensions of the input probabilities map.
            Defaults to 2.
        sigma: the standard deviation for gaussian filter.
            It could be a single value, or `spatial_dims` number of values. Defaults to 0.0.
            Smoothing is applied only when at least one value is non-zero.
        prob_threshold: the probability threshold. The function stops searching once
            the highest probability is no larger than the threshold. The value should be
            no less than 0.0. Defaults to 0.5.
        box_size: the box size (in pixels) to be removed around the pixel with the maximum probability.
            It can be an integer that defines the size of a square or cube,
            or a list containing a different value for each dimension. Defaults to 48.

    Return:
        a list of selected lists, where inner lists contain probability and coordinates.
        For example, for 3D input, the inner lists are in the form of [probability, x, y, z].
        The input probabilities map itself is not modified.

    Raises:
        ValueError: When ``prob_threshold`` is less than 0.0.
        ValueError: When ``box_size`` is a list or tuple, and its length is not equal to `spatial_dims`.
        ValueError: When ``box_size`` has a less than 1 value.

    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(
        self,
        spatial_dims: int = 2,
        sigma: Union[Sequence[float], float, Sequence[torch.Tensor],
                     torch.Tensor] = 0.0,
        prob_threshold: float = 0.5,
        box_size: Union[int, Sequence[int]] = 48,
    ) -> None:
        self.sigma = sigma
        self.spatial_dims = spatial_dims
        # Decide once whether smoothing is enabled. Writing ``if sigma != 0`` directly
        # raises for a multi-element tensor sigma (ambiguous truth value).
        self._apply_filter = self._sigma_is_nonzero(sigma)
        if self._apply_filter:
            self.filter = GaussianFilter(spatial_dims=spatial_dims,
                                         sigma=sigma)
        if prob_threshold < 0:
            raise ValueError("prob_threshold should be no less than 0.0.")
        self.prob_threshold = prob_threshold
        if isinstance(box_size, (int, np.integer)):
            self.box_size = np.asarray([box_size] * spatial_dims)
        elif len(box_size) != spatial_dims:
            raise ValueError(
                "the sequence length of box_size should be the same as spatial_dims."
            )
        else:
            self.box_size = np.asarray(box_size)
        if self.box_size.min() <= 0:
            raise ValueError("box_size should be larger than 0.")

        # The removal window around a peak at index i is [i - lower, i + upper).
        # Because box_size >= 1, upper >= 1, so the window always contains i itself.
        self.box_lower_bd = self.box_size // 2
        self.box_upper_bd = self.box_size - self.box_lower_bd

    @staticmethod
    def _sigma_is_nonzero(sigma) -> bool:
        """Return True when any component of ``sigma`` (scalar, array, tensor or list of them) is non-zero."""
        values = sigma if isinstance(sigma, (list, tuple)) else [sigma]
        return any(bool((torch.as_tensor(v) != 0).any()) for v in values)

    def __call__(self, prob_map: NdarrayOrTensor):
        """
        prob_map: the input probabilities map, it must have shape (H[, W, ...]).

        The map is not modified; suppression runs on a private copy.
        """
        if self._apply_filter:
            if not isinstance(prob_map, torch.Tensor):
                prob_map = torch.as_tensor(prob_map, dtype=torch.float)
            self.filter.to(prob_map.device)
            # NOTE(review): GaussianFilter may expect a (batch, channel, *spatial) input;
            # confirm it handles the un-batched (H[, W, ...]) map passed here.
            prob_map = self.filter(prob_map)

        # The loop below zeroes regions in place. Work on a private, detached copy so the
        # caller's data is never overwritten. Without it, the unfiltered path mutates the
        # input directly, and a filter or as_tensor result may share memory with it.
        if isinstance(prob_map, torch.Tensor):
            prob_map = prob_map.detach().clone()
        else:
            prob_map = np.array(prob_map, copy=True)

        prob_map_shape = prob_map.shape

        outputs = []
        # max() raises on an empty map, and there is nothing to select anyway.
        if 0 in prob_map_shape:
            return outputs

        # Terminates: each iteration zeroes the current peak, and 0 <= prob_threshold.
        while prob_map.max() > self.prob_threshold:
            max_idx = unravel_index(prob_map.argmax(), prob_map_shape)
            prob_max = prob_map[tuple(max_idx)]
            max_idx = max_idx.cpu().numpy() if isinstance(
                max_idx, torch.Tensor) else max_idx
            prob_max = prob_max.item() if isinstance(
                prob_max, torch.Tensor) else prob_max
            outputs.append([prob_max] + list(max_idx))

            idx_min_range = (max_idx - self.box_lower_bd).clip(0, None)
            idx_max_range = (max_idx + self.box_upper_bd).clip(
                None, prob_map_shape)
            # for each dimension, set values during index ranges to 0
            slices = tuple(
                slice(idx_min_range[i], idx_max_range[i])
                for i in range(self.spatial_dims))
            prob_map[slices] = 0

        return outputs