def __init__(
    self,
    sizes: Sequence[Sequence[int]] = ((20, 30, 40),),
    aspect_ratios: Sequence = (((0.5, 1), (1, 0.5)),),
    indexing: str = "ij",
) -> None:
    """
    Store per-feature-map anchor sizes/aspect ratios and build the cell anchors.

    Args:
        sizes: anchor sizes, one inner sequence per feature map. If ``sizes[0]`` is not
            itself a sequence, each element of ``sizes`` is treated as the single size
            of its own feature map.
        aspect_ratios: aspect ratios, one entry per feature map. If ``aspect_ratios[0]``
            is not itself a sequence, the whole ``aspect_ratios`` is reused for every
            feature map. Each individual ratio has ``spatial_dims - 1`` components
            (a scalar for 2D, a pair for 3D).
        indexing: either ``"ij"`` or ``"xy"``.

    Raises:
        ValueError: when the number of size groups differs from the number of
            aspect-ratio groups, or when the inferred spatial dims / ``indexing``
            are not supported.
    """
    super().__init__()

    # A flat sequence such as (32, 64, 128) means one size per feature map.
    if not issequenceiterable(sizes[0]):
        self.sizes = tuple((s,) for s in sizes)
    else:
        self.sizes = ensure_tuple(sizes)

    # A single, non-nested aspect-ratio spec is broadcast to every feature map.
    if not issequenceiterable(aspect_ratios[0]):
        aspect_ratios = (aspect_ratios,) * len(self.sizes)

    if len(self.sizes) != len(aspect_ratios):
        # Implicit string concatenation: a backslash continuation inside the literal
        # leaked source whitespace (and an invalid "\ " escape) into the message.
        raise ValueError(
            "len(sizes) and len(aspect_ratios) should be equal. "
            "It represents the number of feature maps. "
            f"Got {len(self.sizes)} and {len(aspect_ratios)}."
        )

    # One ratio component for 2D anchors, two for 3D anchors.
    spatial_dims = len(ensure_tuple(aspect_ratios[0][0])) + 1
    spatial_dims = look_up_option(spatial_dims, [2, 3])
    self.spatial_dims = spatial_dims

    self.indexing = look_up_option(indexing, ["ij", "xy"])
    self.aspect_ratios = aspect_ratios
    # One set of cell anchors per feature map.
    self.cell_anchors = [
        self.generate_anchors(size, aspect_ratio)
        for size, aspect_ratio in zip(self.sizes, aspect_ratios)
    ]
def ensure_list(vals: Any):
    """
    Return ``vals`` as a new list.

    Dictionaries, and any value that ``issequenceiterable`` rejects, become the sole
    element of a one-item list. Every other sequence-like value is converted with
    ``list()``.
    """
    is_container = issequenceiterable(vals) and not isinstance(vals, dict)
    return list(vals) if is_container else [vals]
def ensure_list_rep(vals: Any, dim: int) -> List[Any]:
    """
    Return a list of exactly ``dim`` items built from ``vals``.

    A value that ``issequenceiterable`` rejects is repeated ``dim`` times. A sequence
    whose length already equals ``dim`` is copied into a new list. Sequences are never
    shortened or padded.

    Args:
        vals: a single value, or a sequence of values.
        dim: required length of the returned list.

    Returns:
        A new list of length ``dim``.

    Raises:
        ValueError: when ``vals`` is a sequence whose length is not ``dim``.
    """
    if not issequenceiterable(vals):
        # The same object is referenced ``dim`` times; it is not copied.
        return [vals] * dim
    if len(vals) == dim:
        return list(vals)
    raise ValueError(f"Sequence must have length {dim}, got {len(vals)}.")
def allow_missing_keys_mode(
    transform: Union[MapTransform, Compose, Tuple[MapTransform, ...], Tuple[Compose, ...]]
):
    """Temporarily set all MapTransforms to not throw an error if keys are missing. After, revert to original states.

    This is a generator used in a ``with`` statement (see the example below).
    NOTE(review): it presumably relies on a ``contextlib.contextmanager`` decorator that
    is not part of this block; confirm it is applied where this function is defined.

    Args:
        transform: a MapTransform, a Compose, or a sequence of them. A sequence (as
            detected by ``issequenceiterable``) is first wrapped in a Compose. A Compose
            is flattened, and only the MapTransforms it contains are affected.

    Raises:
        TypeError: when ``transform`` yields no MapTransform at all.

    Example:

        .. code-block:: python

            data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)}
            t = SpatialPadd(["image", "label"], 10, allow_missing_keys=False)
            _ = t(data)  # would raise exception
            with allow_missing_keys_mode(t):
                _ = t(data)  # OK!

    """
    # If given a sequence of transforms, Compose them to get a single list
    if issequenceiterable(transform):
        transform = Compose(transform)

    # Collect the MapTransforms whose flag will be toggled.
    if isinstance(transform, MapTransform):
        transforms = [transform]
    elif isinstance(transform, Compose):
        # Only keep contained MapTransforms
        transforms = [t for t in transform.flatten().transforms if isinstance(t, MapTransform)]
    else:
        transforms = []

    if not transforms:
        raise TypeError(
            "allow_missing_keys_mode expects either MapTransform(s) or Compose(s) containing MapTransform(s)"
        )

    # Snapshot the flags before touching them, so the finally clause can restore them.
    orig_states = [t.allow_missing_keys for t in transforms]
    try:
        for t in transforms:
            t.allow_missing_keys = True
        yield
    finally:
        # Restore each transform's original flag.
        for t, o_s in zip(transforms, orig_states):
            t.allow_missing_keys = o_s
def decollate(data: Any, idx: int):
    """
    Recursively extract item ``idx`` from a collated batch structure.

    - dict: recurse into every value, keeping the keys.
    - torch.Tensor: take ``data[idx]`` and pass it through ``torch_to_single``.
    - list:
        - An empty list is returned as-is (the same object).
        - If the first element is a tensor, each element is indexed with ``idx`` and
          passed through ``torch_to_single``.
        - If the first element is sequence-iterable, recurse into each element.
        - Otherwise, ``data[idx]`` is returned.

    Raises:
        TypeError: for any other input type.
    """
    if isinstance(data, dict):
        return {key: decollate(value, idx) for key, value in data.items()}
    if isinstance(data, torch.Tensor):
        return torch_to_single(data[idx])
    if not isinstance(data, list):
        raise TypeError(f"Not sure how to de-collate type: {type(data)}")

    if not data:
        return data
    head = data[0]
    if isinstance(head, torch.Tensor):
        return [torch_to_single(item[idx]) for item in data]
    if issequenceiterable(head):
        return [decollate(item, idx) for item in data]
    return data[idx]
def __init__(
    self,
    spatial_dims: int,
    sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor],
    truncated: float = 4.0,
    approx: str = "erf",
    requires_grad: bool = False,
) -> None:
    """
    Args:
        spatial_dims: number of spatial dimensions of the input image.
            must have shape (Batch, channels, H[, W, ...]).
        sigma: std. could be a single value, or `spatial_dims` number of values.
        truncated: spreads how many stds.
        approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".

            - ``erf`` approximation interpolates the error function;
            - ``sampled`` uses a sampled Gaussian kernel;
            - ``scalespace`` corresponds to
              https://en.wikipedia.org/wiki/Scale_space_implementation#The_discrete_Gaussian_kernel
              based on the modified Bessel functions.

        requires_grad: whether to store the gradients for sigma.
            if True, `sigma` will be the initial value of the parameters of this module
            (for example `parameters()` iterator could be used to get the parameters);
            otherwise this module will fix the kernels using `sigma` as the std.

    Raises:
        ValueError: when `sigma` is a sequence whose length differs from `spatial_dims`.
    """
    if issequenceiterable(sigma):
        if len(sigma) != spatial_dims:  # type: ignore
            raise ValueError(
                f"sigma must be a single value or have {spatial_dims} values "
                f"(one per spatial dimension), got {len(sigma)}."  # type: ignore
            )
    else:
        # An independent copy per dimension, so the per-dimension parameters
        # below are not built from one shared object.
        sigma = [deepcopy(sigma) for _ in range(spatial_dims)]  # type: ignore
    super().__init__()
    self.sigma = [
        torch.nn.Parameter(
            torch.as_tensor(s, dtype=torch.float, device=s.device if isinstance(s, torch.Tensor) else None),
            requires_grad=requires_grad,
        )
        for s in sigma  # type: ignore
    ]
    self.truncated = truncated
    self.approx = approx
    # Register each parameter under a stable name so it is exposed via parameters().
    for idx, param in enumerate(self.sigma):
        self.register_parameter(f"kernel_sigma_{idx}", param)