def test_train(self, input_args):
    """Check that a trainable ``GaussianFilter`` recovers the sigma of a fixed one.

    A non-trainable filter built from the ground-truth sigma produces the
    target from an all-ones tensor. A second filter, created with
    ``requires_grad=True`` and a different starting sigma, is fitted to that
    target with Adam. Its sigma values are finally compared with the ground
    truth using a relative tolerance of 1e-2.

    Args:
        input_args: dict of test settings. ``"gt"`` (ground-truth sigma) and
            ``"type"`` (forwarded as ``approx``) are required. ``"dims"``
            (default ``(2, 3, 8)``), ``"device"``, ``"lr"`` (default 0.1) and
            ``"init"`` (default 1.0) are optional.
    """
    shape = input_args.get("dims", (2, 3, 8))

    # CUDA is used only when it is both requested and actually available.
    if input_args.get("device") == "cuda" and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu:0")

    source = torch.ones(*shape).to(device)
    gt = torch.tensor(input_args["gt"], requires_grad=False).to(device)
    g_type = input_args["type"]
    lr = input_args.get("lr", 0.1)
    init_sigma = input_args.get("init", 1.0)

    # static filter to generate a target; the first two axes of the input
    # are not counted as spatial dimensions
    n_spatial = source.ndim - 2
    reference = GaussianFilter(
        spatial_dims=n_spatial, sigma=gt, approx=g_type, requires_grad=False
    )
    reference.to(device)
    target = reference(source)
    self.assertFalse(reference.sigma[0].requires_grad)

    # build trainable
    start_sigma = torch.tensor(init_sigma).to(device)
    trainable = GaussianFilter(
        spatial_dims=n_spatial, sigma=start_sigma, approx=g_type, requires_grad=True
    )
    trainable.to(device)
    self.assertTrue(trainable.sigma[0].requires_grad)

    # train
    optimizer = torch.optim.Adam(trainable.parameters(), lr=lr)
    for s in range(1000):
        optimizer.zero_grad()
        pred = trainable(source)
        loss = (pred - target).square().mean()
        loss.backward()

        # periodic progress report: every 50th step
        if (s + 1) % 50 == 0:
            var = next(iter(trainable.parameters()))
            print(f"step {s} loss {loss}")
            print(var, var.grad)

        # stop before stepping once the fit is already good enough
        if loss.item() < 1e-7:
            break
        optimizer.step()

    # check the result
    print(s, gt)
    gt_is_scalar = gt.dim() == 0
    for idx, learned in enumerate(trainable.sigma):
        # a 0-dim ground truth is compared as a whole, otherwise per axis
        expected = gt.cpu() if gt_is_scalar else gt[idx].cpu().item()
        np.testing.assert_allclose(learned.cpu().item(), expected, rtol=1e-2)
class ProbNMS(Transform):
    """
    Performs probability based non-maximum suppression (NMS) on the probabilities map via
    iteratively selecting the coordinate with highest probability and then removing it as well
    as its surrounding values. The removed range is determined by the parameter `box_size`.
    If multiple coordinates have the same highest probability, only one of them will be selected.

    Args:
        spatial_dims: number of spatial dimensions of the input probabilities map.
            Defaults to 2.
        sigma: the standard deviation for gaussian filter.
            It could be a single value, or `spatial_dims` number of values.
            A scalar 0 (or a tensor/array whose elements are all 0) disables the filtering.
            Defaults to 0.0.
        prob_threshold: the probability threshold, the function will stop searching if
            the highest probability is no larger than the threshold. The value should be
            no less than 0.0. Defaults to 0.5.
        box_size: the box size (in pixel) to be removed around the pixel with the maximum probability.
            It can be an integer that defines the size of a square or cube,
            or a list containing different values for each dimensions. Defaults to 48.

    Return:
        a list of selected lists, where inner lists contain probability and coordinates.
        For example, for 3D input, the inner lists are in the form of [probability, x, y, z].

    Raises:
        ValueError: When ``prob_threshold`` is less than 0.0.
        ValueError: When ``box_size`` is a list or tuple, and its length is not equal to `spatial_dims`.
        ValueError: When ``box_size`` contains a value less than 1.

    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(
        self,
        spatial_dims: int = 2,
        sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor] = 0.0,
        prob_threshold: float = 0.5,
        box_size: Union[int, Sequence[int]] = 48,
    ) -> None:
        self.sigma = sigma
        self.spatial_dims = spatial_dims
        # the smoothing filter only exists when a non-zero sigma is requested
        if self._sigma_is_nonzero(self.sigma):
            self.filter = GaussianFilter(spatial_dims=spatial_dims, sigma=sigma)
        if prob_threshold < 0:
            raise ValueError("prob_threshold should be no less than 0.0.")
        self.prob_threshold = prob_threshold
        # NumPy integer scalars are accepted in the same way as Python ints
        if isinstance(box_size, (int, np.integer)):
            self.box_size = np.asarray([box_size] * spatial_dims)
        elif len(box_size) != spatial_dims:
            raise ValueError(
                "the sequence length of box_size should be the same as spatial_dims."
            )
        else:
            self.box_size = np.asarray(box_size)
        if self.box_size.min() <= 0:
            raise ValueError("box_size should be larger than 0.")

        # The suppressed box spans [idx - lower, idx + upper) along each axis;
        # lower + upper == box_size, so odd sizes are centred on the maximum.
        self.box_lower_bd = self.box_size // 2
        self.box_upper_bd = self.box_size - self.box_lower_bd

    @staticmethod
    def _sigma_is_nonzero(sigma) -> bool:
        """Return True when `sigma` asks for Gaussian filtering."""
        if isinstance(sigma, (torch.Tensor, np.ndarray)):
            # `!=` is elementwise here; using the result directly in an `if`
            # raises for more than one element, so reduce it explicitly.
            return bool((sigma != 0).any())
        # Scalars compare by value. A list/tuple never equals 0, so any
        # sequence enables the filter (same as the plain `sigma != 0` check).
        return bool(sigma != 0)

    def __call__(self, prob_map: NdarrayOrTensor) -> list:
        """
        Select the local maxima of `prob_map` that exceed `prob_threshold`.

        Args:
            prob_map: the input probabilities map, it must have shape (H[, W, ...]).

        Returns:
            a list with one ``[probability, idx_0, idx_1, ...]`` entry per selected
            maximum, in the order they were found (highest probability first).

        Note:
            the map being searched is zeroed in place around every selected
            maximum. When no filtering is applied this is the caller's own
            array/tensor.
        """
        if self._sigma_is_nonzero(self.sigma):
            if not isinstance(prob_map, torch.Tensor):
                prob_map = torch.as_tensor(prob_map, dtype=torch.float)
            self.filter.to(prob_map.device)
            # NOTE(review): the map is passed to the filter as-is; confirm that
            # GaussianFilter handles inputs without batch/channel axes.
            prob_map = self.filter(prob_map)

        prob_map_shape = prob_map.shape

        outputs = []
        while prob_map.max() > self.prob_threshold:
            max_idx = unravel_index(prob_map.argmax(), prob_map_shape)
            prob_max = prob_map[tuple(max_idx)]
            # convert to NumPy / Python scalars so the result does not depend
            # on whether the map is a tensor or an array
            max_idx = max_idx.cpu().numpy() if isinstance(max_idx, torch.Tensor) else max_idx
            prob_max = prob_max.item() if isinstance(prob_max, torch.Tensor) else prob_max
            outputs.append([prob_max] + list(max_idx))

            # clamp the box to the bounds of the map
            idx_min_range = (max_idx - self.box_lower_bd).clip(0, None)
            idx_max_range = (max_idx + self.box_upper_bd).clip(None, prob_map_shape)
            # for each dimension, set values during index ranges to 0
            slices = tuple(
                slice(idx_min_range[i], idx_max_range[i]) for i in range(self.spatial_dims)
            )
            prob_map[slices] = 0

        return outputs