Example #1
    def __init__(
        self,
        sizes: Sequence[Sequence[int]] = ((20, 30, 40),),
        aspect_ratios: Sequence = (((0.5, 1), (1, 0.5)),),
        indexing: str = "ij",
    ) -> None:
        super().__init__()

        if not issequenceiterable(sizes[0]):
            self.sizes = tuple((s,) for s in sizes)
        else:
            self.sizes = ensure_tuple(sizes)
        if not issequenceiterable(aspect_ratios[0]):
            aspect_ratios = (aspect_ratios,) * len(self.sizes)

        if len(self.sizes) != len(aspect_ratios):
            raise ValueError(
                "len(sizes) and len(aspect_ratios) should be equal. \
                It represents the number of feature maps."
            )

        spatial_dims = len(ensure_tuple(aspect_ratios[0][0])) + 1
        spatial_dims = look_up_option(spatial_dims, [2, 3])
        self.spatial_dims = spatial_dims

        self.indexing = look_up_option(indexing, ["ij", "xy"])

        self.aspect_ratios = aspect_ratios
        self.cell_anchors = [
            self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(self.sizes, aspect_ratios)
        ]
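Across these examples, look_up_option checks a value against a collection of supported options and returns the matched entry, raising a descriptive ValueError otherwise. A minimal standalone sketch of that behaviour, assuming monai.utils.look_up_option:

from monai.utils import look_up_option

spatial_dims = look_up_option(3, [2, 3])       # returns 3, the matched option
indexing = look_up_option("ij", ["ij", "xy"])  # returns "ij"

# an unsupported value raises a ValueError listing the supported options
try:
    look_up_option("zz", ["ij", "xy"])
except ValueError as err:
    print(err)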
Example #2
    def __call__(self,
                 img: NdarrayOrTensor,
                 randomize: bool = True,
                 device: Optional[torch.device] = None) -> NdarrayOrTensor:
        img = convert_to_tensor(img, track_meta=get_track_meta())
        if randomize:
            self.randomize()

        if not self._do_transform:
            return img

        device = device if device is not None else self.device

        field = self.sfield()

        dgrid = self.grid + field.to(self.grid_dtype)
        dgrid = moveaxis(dgrid, 1, -1)  # type: ignore

        img_t = convert_to_tensor(img[None], torch.float32, device)

        out = grid_sample(
            input=img_t,
            grid=dgrid,
            mode=look_up_option(self.grid_mode, GridSampleMode),
            align_corners=self.grid_align_corners,
            padding_mode=look_up_option(self.grid_padding_mode,
                                        GridSamplePadMode),
        )

        out_t, *_ = convert_to_dst_type(out.squeeze(0), img)

        return out_t
Example #3
def _load_state_dict(model: nn.Module, arch: str, progress: bool):
    """
    This function is used to load pretrained models.
    Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.

    """
    model_urls = {
        "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
        "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
        "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
    }
    model_url = look_up_option(arch, model_urls, None)
    if model_url is None:
        raise ValueError(
            "only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights."
        )

    pattern = re.compile(
        r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + ".layers" + res.group(2) + res.group(3)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]

    model_dict = model.state_dict()
    state_dict = {
        k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
    }
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
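When the supported options are a mapping, as with model_urls above, look_up_option returns the mapped value, and a third argument supplies a default instead of raising for unknown keys. A small sketch of that lookup pattern:

from monai.utils import look_up_option

model_urls = {"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth"}

print(look_up_option("densenet121", model_urls, None))  # the URL mapped to the key
print(look_up_option("densenet264", model_urls, None))  # None: key not found, the default is returned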
Example #4
    def __call__(self, randomize=False) -> torch.Tensor:
        if randomize:
            self.randomize()

        field = self.field.clone()

        if self.spatial_zoom is not None:
            resized_field = interpolate(
                input=field,
                scale_factor=self.spatial_zoom,
                mode=look_up_option(self.mode, InterpolateMode),
                align_corners=self.align_corners,
                recompute_scale_factor=False,
            )

            mina = resized_field.min()
            maxa = resized_field.max()
            minv = self.field.min()
            maxv = self.field.max()

            # faster than rescale_array, this uses in-place operations and doesn't perform unneeded range checks
            norm_field = (resized_field.squeeze(0) - mina).div_(maxa - mina)
            field = norm_field.mul_(maxv - minv).add_(minv)

        return field
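The in-place block above amounts to a min-max rescale of the interpolated field back into the value range of the original field. A small standalone check of the same arithmetic (the tensors here are illustrative, not taken from the source):

import torch

resized = torch.tensor([[0.2, 0.6], [1.0, 1.4]])
minv, maxv = -1.0, 1.0  # range of the original (pre-zoom) field

norm = (resized - resized.min()) / (resized.max() - resized.min())
rescaled = norm * (maxv - minv) + minv
print(rescaled)  # tensor([[-1.0000, -0.3333], [ 0.3333,  1.0000]])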
Example #5
def _load_state_dict(model: nn.Module, arch: str, progress: bool):
    """
    This function is used to load pretrained models.
    """
    model_url = look_up_option(arch, SE_NET_MODELS, None)
    if model_url is None:
        raise ValueError(
            "only 'senet154', 'se_resnet50', 'se_resnet101',  'se_resnet152', 'se_resnext50_32x4d', "
            +
            "and se_resnext101_32x4d are supported to load pretrained weights."
        )

    pattern_conv = re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$")
    pattern_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$")
    pattern_se = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$")
    pattern_se2 = re.compile(r"^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$")
    pattern_down_conv = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$")
    pattern_down_bn = re.compile(r"^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$")

    if isinstance(model_url, dict):
        download_url(model_url["url"], filepath=model_url["filename"])
        state_dict = torch.load(model_url["filename"], map_location=None)
    else:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        new_key = None
        if pattern_conv.match(key):
            new_key = re.sub(pattern_conv, r"\1conv.\2", key)
        elif pattern_bn.match(key):
            new_key = re.sub(pattern_bn, r"\1conv\2adn.N.\3", key)
        elif pattern_se.match(key):
            state_dict[key] = state_dict[key].squeeze()
            new_key = re.sub(pattern_se, r"\1se_layer.fc.0.\2", key)
        elif pattern_se2.match(key):
            state_dict[key] = state_dict[key].squeeze()
            new_key = re.sub(pattern_se2, r"\1se_layer.fc.2.\2", key)
        elif pattern_down_conv.match(key):
            new_key = re.sub(pattern_down_conv, r"\1project.conv.\2", key)
        elif pattern_down_bn.match(key):
            new_key = re.sub(pattern_down_bn, r"\1project.adn.N.\2", key)
        if new_key:
            state_dict[new_key] = state_dict[key]
            del state_dict[key]

    model_dict = model.state_dict()
    state_dict = {
        k: v
        for k, v in state_dict.items()
        if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
    }
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
Example #6
    def __init__(
        self,
        feature_map_scales: Union[Sequence[int], Sequence[float]] = (1, 2, 4, 8),
        base_anchor_shapes: Union[Sequence[Sequence[int]], Sequence[Sequence[float]]] = (
            (32, 32, 32),
            (48, 20, 20),
            (20, 48, 20),
            (20, 20, 48),
        ),
        indexing: str = "ij",
    ) -> None:

        nn.Module.__init__(self)

        spatial_dims = len(base_anchor_shapes[0])
        spatial_dims = look_up_option(spatial_dims, [2, 3])
        self.spatial_dims = spatial_dims

        self.indexing = look_up_option(indexing, ["ij", "xy"])

        base_anchor_shapes_t = torch.Tensor(base_anchor_shapes)
        self.cell_anchors = [self.generate_anchors_using_shape(s * base_anchor_shapes_t) for s in feature_map_scales]
Example #7
 def __init__(
     self,
     device: torch.device,
     val_data_loader: Iterable | DataLoader,
     epoch_length: int | None = None,
     non_blocking: bool = False,
     prepare_batch: Callable = default_prepare_batch,
     iteration_update: Callable[[Engine, Any], Any] | None = None,
     postprocessing: Transform | None = None,
     key_val_metric: dict[str, Metric] | None = None,
     additional_metrics: dict[str, Metric] | None = None,
     metric_cmp_fn: Callable = default_metric_cmp_fn,
     val_handlers: Sequence | None = None,
     amp: bool = False,
     mode: ForwardMode | str = ForwardMode.EVAL,
     event_names: list[str | EventEnum] | None = None,
     event_to_attr: dict | None = None,
     decollate: bool = True,
     to_kwargs: dict | None = None,
     amp_kwargs: dict | None = None,
 ) -> None:
     super().__init__(
         device=device,
         max_epochs=1,
         data_loader=val_data_loader,
         epoch_length=epoch_length,
         non_blocking=non_blocking,
         prepare_batch=prepare_batch,
         iteration_update=iteration_update,
         postprocessing=postprocessing,
         key_metric=key_val_metric,
         additional_metrics=additional_metrics,
         metric_cmp_fn=metric_cmp_fn,
         handlers=val_handlers,
         amp=amp,
         event_names=event_names,
         event_to_attr=event_to_attr,
         decollate=decollate,
         to_kwargs=to_kwargs,
         amp_kwargs=amp_kwargs,
     )
     mode = look_up_option(mode, ForwardMode)
     if mode == ForwardMode.EVAL:
         self.mode = eval_mode
     elif mode == ForwardMode.TRAIN:
         self.mode = train_mode
     else:
         raise ValueError(
             f"unsupported mode: {mode}, should be 'eval' or 'train'.")
Example #8
    def __init__(
        self,
        kernel_type: str = "gaussian",
        num_bins: int = 23,
        sigma_ratio: float = 0.5,
        reduction: Union[LossReduction, str] = LossReduction.MEAN,
        smooth_nr: float = 1e-7,
        smooth_dr: float = 1e-7,
    ) -> None:
        """
        Args:
            kernel_type: {``"gaussian"``, ``"b-spline"``}
                ``"gaussian"``: adapted from DeepReg
                Reference: https://dspace.mit.edu/handle/1721.1/123142, Section 3.1, equation 3.1-3.5, Algorithm 1.
                ``"b-spline"``: based on the method of Mattes et al [1,2] and adapted from ITK
                References:
                  [1] "Nonrigid multimodality image registration"
                      D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank
                      Medical Imaging 2001: Image Processing, 2001, pp. 1609-1620.
                  [2] "PET-CT Image Registration in the Chest Using Free-form Deformations"
                      D. Mattes, D. R. Haynor, H. Vesselle, T. Lewellen and W. Eubank
                      IEEE Transactions on Medical Imaging. Vol.22, No.1,
                      January 2003. pp.120-128.

            num_bins: number of bins for intensity
            sigma_ratio: a hyperparameter for the Gaussian kernel function
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
            smooth_nr: a small constant added to the numerator to avoid nan.
            smooth_dr: a small constant added to the denominator to avoid nan.
        """
        super().__init__(reduction=LossReduction(reduction).value)
        if num_bins <= 0:
            raise ValueError("num_bins must > 0, got {num_bins}")
        bin_centers = torch.linspace(0.0, 1.0, num_bins)  # (num_bins,)
        sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio
        self.kernel_type = look_up_option(kernel_type, ["gaussian", "b-spline"])
        self.num_bins = num_bins
        if self.kernel_type == "gaussian":
            self.preterm = 1 / (2 * sigma**2)
            self.bin_centers = bin_centers[None, None, ...]
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
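For the default Gaussian kernel, sigma follows from the bin spacing: with num_bins evenly spaced centers on [0, 1], the spacing is 1 / (num_bins - 1), so sigma = sigma_ratio / (num_bins - 1). A quick numeric check of the defaults used above:

import torch

num_bins, sigma_ratio = 23, 0.5
bin_centers = torch.linspace(0.0, 1.0, num_bins)
sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio  # 0.5 / 22
preterm = 1 / (2 * sigma**2)
print(float(sigma), float(preterm))  # ~0.0227 and ~968.0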
Example #9
    def __init__(
        self,
        spatial_dims: int = 3,
        kernel_size: int = 3,
        kernel_type: str = "rectangular",
        reduction: Union[LossReduction, str] = LossReduction.MEAN,
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        ndim: Optional[int] = None,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3.
            kernel_size: kernel spatial size, must be odd.
            kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
            smooth_nr: a small constant added to the numerator to avoid nan.
            smooth_dr: a small constant added to the denominator to avoid nan.

        .. deprecated:: 0.6.0
            ``ndim`` is deprecated, use ``spatial_dims``.
        """
        super().__init__(reduction=LossReduction(reduction).value)

        if ndim is not None:
            spatial_dims = ndim
        self.ndim = spatial_dims
        if self.ndim not in {1, 2, 3}:
            raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported")

        self.kernel_size = kernel_size
        if self.kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {self.kernel_size}")

        _kernel = look_up_option(kernel_type, kernel_dict)
        self.kernel = _kernel(self.kernel_size)
        self.kernel_vol = self.get_kernel_vol()

        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
Example #10
    def _make_layer(
        self,
        block: Type[Union[ResNetBlock, ResNetBottleneck]],
        planes: int,
        blocks: int,
        spatial_dims: int,
        shortcut_type: str,
        stride: int = 1,
    ) -> nn.Sequential:

        conv_type: Callable = Conv[Conv.CONV, spatial_dims]
        norm_type: Callable = Norm[Norm.BATCH, spatial_dims]

        downsample: Union[nn.Module, partial, None] = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            if look_up_option(shortcut_type, {"A", "B"}) == "A":
                downsample = partial(
                    self._downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride,
                    spatial_dims=spatial_dims,
                )
            else:
                downsample = nn.Sequential(
                    conv_type(self.in_planes,
                              planes * block.expansion,
                              kernel_size=1,
                              stride=stride),
                    norm_type(planes * block.expansion),
                )

        layers = [
            block(in_planes=self.in_planes,
                  planes=planes,
                  spatial_dims=spatial_dims,
                  stride=stride,
                  downsample=downsample)
        ]

        self.in_planes = planes * block.expansion
        for _i in range(1, blocks):
            layers.append(
                block(self.in_planes, planes, spatial_dims=spatial_dims))

        return nn.Sequential(*layers)
Example #11
 def __init__(
     self,
     device: torch.device,
     val_data_loader: Union[Iterable, DataLoader],
     epoch_length: Optional[int] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = default_prepare_batch,
     iteration_update: Optional[Callable] = None,
     postprocessing: Optional[Transform] = None,
     key_val_metric: Optional[Dict[str, Metric]] = None,
     additional_metrics: Optional[Dict[str, Metric]] = None,
     metric_cmp_fn: Callable = default_metric_cmp_fn,
     val_handlers: Optional[Sequence] = None,
     amp: bool = False,
     mode: Union[ForwardMode, str] = ForwardMode.EVAL,
     event_names: Optional[List[Union[str, EventEnum]]] = None,
     event_to_attr: Optional[dict] = None,
     decollate: bool = True,
 ) -> None:
     super().__init__(
         device=device,
         max_epochs=1,
         data_loader=val_data_loader,
         epoch_length=epoch_length,
         non_blocking=non_blocking,
         prepare_batch=prepare_batch,
         iteration_update=iteration_update,
         postprocessing=postprocessing,
         key_metric=key_val_metric,
         additional_metrics=additional_metrics,
         metric_cmp_fn=metric_cmp_fn,
         handlers=val_handlers,
         amp=amp,
         event_names=event_names,
         event_to_attr=event_to_attr,
         decollate=decollate,
     )
      mode = look_up_option(mode, ForwardMode)
     if mode == ForwardMode.EVAL:
         self.mode = eval_mode
     elif mode == ForwardMode.TRAIN:
         self.mode = train_mode
     else:
         raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.")
Example #12
def _load_state_dict(model: nn.Module, arch: str, progress: bool,
                     adv_prop: bool) -> None:
    if adv_prop:
        arch = arch.split("efficientnet-")[-1] + "-ap"
    model_url = look_up_option(arch, url_map, None)
    if model_url is None:
        print(f"pretrained weights of {arch} is not provided")
    else:
        # load state dict from url
        model_url = url_map[arch]
        pretrain_state_dict = model_zoo.load_url(model_url, progress=progress)
        model_state_dict = model.state_dict()

        pattern = re.compile(r"(.+)\.\d+(\.\d+\..+)")
        for key, value in model_state_dict.items():
            pretrain_key = re.sub(pattern, r"\1\2", key)
            if pretrain_key in pretrain_state_dict and value.shape == pretrain_state_dict[pretrain_key].shape:
                model_state_dict[key] = pretrain_state_dict[pretrain_key]

        model.load_state_dict(model_state_dict)
Example #13
    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        hidden_size: int,
        num_heads: int,
        pos_embed: str,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            img_size: dimension of input image.
            patch_size: dimension of patch size.
            hidden_size: dimension of hidden layer.
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            dropout_rate: fraction of the input units to drop.
            spatial_dims: number of spatial dimensions.


        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden size should be divisible by num_heads.")

        self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)

        img_size = ensure_tuple_rep(img_size, spatial_dims)
        patch_size = ensure_tuple_rep(patch_size, spatial_dims)
        for m, p in zip(img_size, patch_size):
            if m < p:
                raise ValueError("patch_size should be smaller than img_size.")
            if self.pos_embed == "perceptron" and m % p != 0:
                raise ValueError(
                    "patch_size should be divisible by img_size for perceptron."
                )
        self.n_patches = np.prod(
            [im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
        self.patch_dim = int(in_channels * np.prod(patch_size))

        self.patch_embeddings: nn.Module
        if self.pos_embed == "conv":
            self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
                in_channels=in_channels,
                out_channels=hidden_size,
                kernel_size=patch_size,
                stride=patch_size)
        elif self.pos_embed == "perceptron":
            # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
            chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
            from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
            to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
            axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)}
            self.patch_embeddings = nn.Sequential(
                Rearrange(f"{from_chars} -> {to_chars}", **axes_len),
                nn.Linear(self.patch_dim, hidden_size))
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, self.n_patches, hidden_size))
        self.dropout = nn.Dropout(dropout_rate)
        trunc_normal_(self.position_embeddings,
                      mean=0.0,
                      std=0.02,
                      a=-2.0,
                      b=2.0)
        self.apply(self._init_weights)
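For the perceptron branch above, the constructed einops pattern with spatial_dims=3 is "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)". The string construction can be checked on its own (the patch_size below is only illustrative):

spatial_dims, patch_size = 3, (16, 16, 16)
chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
to_chars = f"b ({' '.join(c[0] for c in chars)}) ({' '.join(c[1] for c in chars)} c)"
axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
print(f"{from_chars} -> {to_chars}")  # b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)
print(axes_len)                       # {'p1': 16, 'p2': 16, 'p3': 16}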
Example #14
def dtype_numpy_to_torch(dtype):
    """Convert a numpy dtype to its torch equivalent."""
    # np dtypes can be given as np.float32 and np.dtype(np.float32) so unify them
    dtype = np.dtype(dtype) if isinstance(dtype, (type, str)) else dtype
    return look_up_option(dtype, _np_to_torch_dtype)
Example #15
def dtype_torch_to_numpy(dtype):
    """Convert a torch dtype to its numpy equivalent."""
    return look_up_option(dtype, _torch_to_np_dtype)
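In both helpers the supported options are plain dicts keyed by dtype, so look_up_option simply returns the mapped dtype. Assuming _np_to_torch_dtype and _torch_to_np_dtype are the module-level mappings their names suggest, usage would look like:

import numpy as np
import torch

print(dtype_numpy_to_torch(np.float32))     # torch.float32
print(dtype_numpy_to_torch("float32"))      # also torch.float32: strings are normalized via np.dtype first
print(dtype_torch_to_numpy(torch.float32))  # dtype('float32')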