Beispiel #1
0
    def __init__(
        self,
        sizes: Sequence[Sequence[int]] = ((20, 30, 40),),
        aspect_ratios: Sequence = (((0.5, 1), (1, 0.5)),),
        indexing: str = "ij",
    ) -> None:
        """
        Args:
            sizes: anchor sizes, one inner sequence per feature map. If ``sizes[0]`` is not itself a
                sequence, ``sizes`` is read as one size per feature map (each ``s`` becomes ``(s,)``).
            aspect_ratios: aspect ratios, one entry per feature map. If ``aspect_ratios[0]`` is not a
                sequence, the same ``aspect_ratios`` is reused for every feature map. The number of
                spatial dimensions is inferred as ``len(ensure_tuple(aspect_ratios[0][0])) + 1`` and
                must be 2 or 3, i.e. scalar ratios for 2-D and pairs of ratios for 3-D.
            indexing: ``"ij"`` or ``"xy"``.

        Raises:
            ValueError: when ``sizes`` and ``aspect_ratios`` describe a different number of
                feature maps.

        NOTE(review): the broadcast of ``aspect_ratios`` only triggers when ``aspect_ratios[0]`` is a
        scalar. 3-D ratios (pairs) must therefore always be nested per feature map; a flat sequence of
        pairs is interpreted as one entry per feature map instead. Presumably ``look_up_option``
        rejects a spatial dimension other than 2 or 3 -- confirm.
        """
        super().__init__()

        if not issequenceiterable(sizes[0]):
            # flat sizes: one size per feature map
            self.sizes = tuple((s,) for s in sizes)
        else:
            self.sizes = ensure_tuple(sizes)
        if not issequenceiterable(aspect_ratios[0]):
            # a single set of ratios is shared by every feature map
            aspect_ratios = (aspect_ratios,) * len(self.sizes)

        if len(self.sizes) != len(aspect_ratios):
            raise ValueError(
                "len(sizes) and len(aspect_ratios) should be equal. "
                "It represents the number of feature maps. "
                f"Got {len(self.sizes)} and {len(aspect_ratios)}."
            )

        # a scalar ratio (length 1) means 2-D, a pair of ratios means 3-D
        self.spatial_dims = look_up_option(len(ensure_tuple(aspect_ratios[0][0])) + 1, [2, 3])

        self.indexing = look_up_option(indexing, ["ij", "xy"])

        self.aspect_ratios = aspect_ratios
        # one set of base anchors per feature map
        self.cell_anchors = [
            self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(self.sizes, aspect_ratios)
        ]
Beispiel #2
0
def ensure_list(vals: Any):
    """
    Return ``vals`` as a new list.

    Sequence-like iterables (as judged by ``issequenceiterable``) are converted with ``list(...)``.
    Anything else is wrapped in a single-element list, including a ``dict``, which is deliberately
    kept whole rather than expanded into its keys.
    """
    if issequenceiterable(vals) and not isinstance(vals, dict):
        return list(vals)
    return [vals]
Beispiel #3
0
def ensure_list_rep(vals: Any, dim: int) -> List[Any]:
    """
    Returns a list of exactly ``dim`` values built from ``vals``.

    If ``vals`` is not a sequence-like iterable (as judged by ``issequenceiterable``), it is
    repeated ``dim`` times. Otherwise its items are returned as a new list, which must contain
    exactly ``dim`` elements; the input is never shortened or padded.

    Args:
        vals: a single value, or an iterable of values (iterables without ``len``, such as
            generators, are accepted too).
        dim: required number of values in the output.

    Raises:
        ValueError: When ``vals`` is an iterable whose number of items is not ``dim``.
    """
    if not issequenceiterable(vals):
        # the output holds ``dim`` references to the same object
        return [vals] * dim
    # materialize once so iterables without ``__len__`` are supported as well
    items = list(vals)
    if len(items) == dim:
        return items

    raise ValueError(f"Sequence must have length {dim}, got {len(items)}.")
Beispiel #4
0
def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTransform], Tuple[Compose]]):
    """
    Temporarily set ``allow_missing_keys = True`` on every MapTransform in ``transform`` so missing
    keys do not raise, then restore each transform's original setting on exit.

    Args:
        transform: a MapTransform, a Compose, or a sequence of either. Sequences are first wrapped
            in a single ``Compose``.

    Raises:
        TypeError: when no MapTransform can be found in ``transform``. Because this is a generator,
            the error surfaces when it is first advanced (on entering the ``with`` block).

    Example:

    .. code-block:: python

        data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)}
        t = SpatialPadd(["image", "label"], 10, allow_missing_keys=False)
        _ = t(data)  # would raise exception
        with allow_missing_keys_mode(t):
            _ = t(data)  # OK!
    """
    # NOTE(review): this body contains ``yield``; the ``with`` usage above relies on it being
    # wrapped by ``contextlib.contextmanager`` outside this view -- confirm.
    if issequenceiterable(transform):
        # bundle a tuple/list of transforms into one Compose so it can be flattened below
        transform = Compose(transform)

    if isinstance(transform, MapTransform):
        map_transforms = [transform]
    elif isinstance(transform, Compose):
        # flatten nested Composes and keep only the MapTransforms
        map_transforms = list(filter(lambda tr: isinstance(tr, MapTransform), transform.flatten().transforms))
    else:
        map_transforms = []

    if not map_transforms:
        raise TypeError(
            "allow_missing_keys_mode expects either MapTransform(s) or Compose(s) containing MapTransform(s)"
        )

    # snapshot before mutating so the exact previous values can be restored
    saved_flags = [tr.allow_missing_keys for tr in map_transforms]
    try:
        for tr in map_transforms:
            tr.allow_missing_keys = True
        yield
    finally:
        # restore even if the body of the ``with`` block raised
        for tr, flag in zip(map_transforms, saved_flags):
            tr.allow_missing_keys = flag
Beispiel #5
0
 def decollate(data: Any, idx: int):
     """
     Recursively pick element ``idx`` out of a collated ``data`` structure.

     - ``dict``: applied to every value; keys are kept.
     - ``torch.Tensor``: ``data[idx]`` is passed through ``torch_to_single``.
     - ``list``: an empty list is returned unchanged (the same object). A list of tensors yields
       ``torch_to_single(item[idx])`` for each item. A list whose first item is sequence-like is
       recursed into item by item. Otherwise ``data[idx]`` itself is returned.

     Only ``data[0]`` is inspected to choose the list strategy. Any other type raises ``TypeError``.
     """
     # NOTE(review): ``torch_to_single`` is defined outside this view; presumably it unwraps a
     # single-element tensor -- confirm.
     if isinstance(data, dict):
         return {key: decollate(value, idx) for key, value in data.items()}
     if isinstance(data, torch.Tensor):
         return torch_to_single(data[idx])
     if not isinstance(data, list):
         raise TypeError(f"Not sure how to de-collate type: {type(data)}")
     if not data:
         return data
     head = data[0]
     if isinstance(head, torch.Tensor):
         return [torch_to_single(item[idx]) for item in data]
     if issequenceiterable(head):
         return [decollate(item, idx) for item in data]
     return data[idx]
Beispiel #6
0
    def __init__(
        self,
        spatial_dims: int,
        sigma: Union[Sequence[float], float, Sequence[torch.Tensor], torch.Tensor],
        truncated: float = 4.0,
        approx: str = "erf",
        requires_grad: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions of the input image.
                The input is expected to have shape (Batch, channels, H[, W, ...]).
            sigma: std. could be a single value, or `spatial_dims` number of values.
            truncated: spreads how many stds.
            approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".

                - ``erf`` approximation interpolates the error function;
                - ``sampled`` uses a sampled Gaussian kernel;
                - ``scalespace`` corresponds to
                  https://en.wikipedia.org/wiki/Scale_space_implementation#The_discrete_Gaussian_kernel
                  based on the modified Bessel functions.

            requires_grad: whether to store the gradients for sigma.
                if True, `sigma` will be the initial value of the parameters of this module
                (for example `parameters()` iterator could be used to get the parameters);
                otherwise this module will fix the kernels using `sigma` as the std.

        Raises:
            ValueError: when ``sigma`` is a sequence whose length is not ``spatial_dims``.
        """
        if issequenceiterable(sigma):
            n_sigma = len(sigma)  # type: ignore
            if n_sigma != spatial_dims:
                raise ValueError(
                    f"sigma must be a single value or have spatial_dims={spatial_dims} values, got {n_sigma}."
                )
        else:
            # one independent copy per spatial dimension, so e.g. a tensor sigma is not shared
            # between the per-dimension parameters
            sigma = [deepcopy(sigma) for _ in range(spatial_dims)]  # type: ignore
        super().__init__()
        self.sigma = [
            torch.nn.Parameter(
                # keep a tensor sigma on its original device; plain numbers use the default device
                torch.as_tensor(s, dtype=torch.float, device=s.device if isinstance(s, torch.Tensor) else None),
                requires_grad=requires_grad,
            )
            for s in sigma  # type: ignore
        ]
        self.truncated = truncated
        self.approx = approx
        # register each per-dimension sigma explicitly, since they are held in a plain list
        for idx, param in enumerate(self.sigma):
            self.register_parameter(f"kernel_sigma_{idx}", param)