Пример #1
0
def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tensor:
    """
    Filtering `x` with `kernel` independently for each batch and channel respectively.

    Every (batch, channel) slice of `x` is convolved with its own kernel by folding the
    batch and channel dimensions into a single channel axis and running one grouped
    convolution with `groups == batch * channels`.

    Args:
        x: the input image, must have shape (batch, channels, H[, W, D]).
            Non-contiguous tensors (e.g. the result of `permute`) are accepted.
        kernel: `kernel` must at least have the spatial shape (H_k[, W_k, D_k]).
            `kernel` shape must be broadcastable to the `batch` and `channels` dimensions of `x`.
            It is converted to the dtype and device of `x` before filtering.
        kwargs: keyword arguments passed to `conv*d()` functions.
            `padding` and `stride` get defaults when they are not provided.

    Returns:
        The filtered `x`.

    Raises:
        TypeError: when `x` or `kernel` is not a `torch.Tensor`.
        ValueError: when `x` has no spatial dimension, or when the number of `kernel`
            dimensions does not match the spatial dimensions of `x`.
        NotImplementedError: when `x` has more than 3 spatial dimensions.

    Examples:

    .. code-block:: python

        >>> import torch
        >>> from monai.networks.layers import apply_filter
        >>> img = torch.rand(2, 5, 10, 10)  # batch_size 2, channels 5, 10x10 2D images
        >>> out = apply_filter(img, torch.rand(3, 3))   # spatial kernel
        >>> out = apply_filter(img, torch.rand(5, 3, 3))  # channel-wise kernels
        >>> out = apply_filter(img, torch.rand(2, 5, 3, 3))  # batch-, channel-wise kernels

    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")
    if not isinstance(kernel, torch.Tensor):
        raise TypeError(f"kernel must be a torch.Tensor but is {type(kernel).__name__}.")
    if x.dim() < 3:
        # without this guard `n_spatial - 1` below would be -1 and silently select `F.conv3d`
        raise ValueError(
            f"x must have shape (batch, channels, H[, W, D]) but got shape {tuple(x.shape)}."
        )
    batch, chns, *spatials = x.shape
    n_spatial = len(spatials)
    if n_spatial > 3:
        raise NotImplementedError(f"Only spatial dimensions up to 3 are supported but got {n_spatial}.")
    k_size = len(kernel.shape)
    if k_size < n_spatial or k_size > n_spatial + 2:
        raise ValueError(
            f"kernel must have {n_spatial} ~ {n_spatial + 2} dimensions to match the input shape {x.shape}."
        )
    kernel = kernel.to(x)
    # broadcast kernel size to (batch chns, spatial_kernel_size)
    kernel = kernel.expand(batch, chns, *kernel.shape[(k_size - n_spatial) :])
    kernel = kernel.reshape(-1, 1, *kernel.shape[2:])  # group=1
    # `reshape` instead of `view`: `view` raises on inputs whose strides do not allow
    # merging the batch and channel dimensions (e.g. permuted tensors)
    x = x.reshape(1, kernel.shape[0], *spatials)
    conv = (F.conv1d, F.conv2d, F.conv3d)[n_spatial - 1]
    if "padding" not in kwargs:
        if pytorch_after(1, 10):
            kwargs["padding"] = "same"
        else:
            # even-sized kernels are not supported
            kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]

    if "stride" not in kwargs:
        kwargs["stride"] = 1
    output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs)
    # unfold the merged channel axis back into (batch, channels)
    return output.reshape(batch, chns, *output.shape[2:])
Пример #2
0
    def __init__(
        self,
        nn_module,
        target_layer_names: Union[str, Sequence[str]],
        register_forward: bool = False,
        register_backward: bool = False,
    ):
        """
        Wrap `nn_module` and attach hooks to the sub-modules named in `target_layer_names`.

        Args:
            nn_module: the model to be wrapped.
            target_layer_names: the names of the layer to cache.
            register_forward: whether to cache the forward pass output corresponding to `target_layer_names`.
            register_backward: whether to cache the backward pass output corresponding to `target_layer_names`.

        A warning is issued when the number of matched sub-modules differs from the
        number of requested target layers.
        """
        self.model = nn_module
        self.target_layers = ensure_tuple(target_layer_names)

        # caches and bookkeeping attributes, all start empty / unset
        self.gradients: Dict[str, torch.Tensor] = {}
        self.activations: Dict[str, torch.Tensor] = {}
        self.score = None
        self.class_idx = None
        self.register_backward = register_backward
        self.register_forward = register_forward

        n_matched = 0
        for layer_name, layer in nn_module.named_modules():
            if layer_name in self.target_layers:
                n_matched += 1
                if self.register_backward:
                    if not pytorch_after(1, 8):
                        layer.register_backward_hook(self.backward_hook(layer_name))
                    else:
                        # inplace=True causes errors for register_full_backward_hook,
                        # so the flag is switched off on the module before hooking it
                        if layer.__dict__.get("inplace"):
                            layer.__dict__["inplace"] = False
                        layer.register_full_backward_hook(self.backward_hook(layer_name))
                if self.register_forward:
                    layer.register_forward_hook(self.forward_hook(layer_name))

        if self.target_layers and n_matched != len(self.target_layers):
            warnings.warn(
                f"Not all target_layers exist in the network module: targets: {self.target_layers}."
            )
 def save_func(engine):
     """
     Pass the content of `engine.state.output` to `saver` (a name defined outside this function).

     When `pytorch_after(1, 9, 1)` holds, every item of the output is passed to `saver`
     one at a time; otherwise only the first item is passed.
     """
     if not pytorch_after(1, 9, 1):
         saver(engine.state.output[0])
         return
     for item in engine.state.output:
         saver(item)
Пример #4
0
    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
        """
        Run one training iteration: forward pass, loss computation, backward pass and optimizer step.

        Predictions and targets are both required to be tuples of equal length
        (multi-task setting); they are passed as-is to `self.loss_function`.

        Args:
            engine: the engine running this iteration; its `state.output` is (re)populated
                here and iteration events are fired on it after each stage.
            batchdata: the data of the current iteration, forwarded to `self.prepare_batch`.

        Returns:
            `engine.state.output`: a dict holding the inputs, targets, predictions and loss
            under `self.keys["IMAGE"]`, `self.keys["LABEL"]`, `self.keys["PRED"]` and
            `self.keys["LOSS"]` respectively.

        Raises:
            ValueError: when `batchdata` is None, when the predictions or the targets are
                not tuples, or when their lengths differ.
        """
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")
        batch = self.prepare_batch(batchdata, engine.state.device,
                                   engine.non_blocking)
        if len(batch) == 2:
            inputs, targets = batch
            args: Tuple = ()
            kwargs: Dict = {}
        else:
            # `prepare_batch` may also return extra positional / keyword arguments for the inferer
            inputs, targets, args, kwargs = batch
        # put iteration outputs into engine.state
        engine.state.output = {
            self.keys["IMAGE"]: inputs,
            self.keys["LABEL"]: targets
        }

        def _compute_pred_loss():
            # forward pass, then validation of the multi-task contract, then the loss
            preds = self.inferer(inputs, self.network, *args, **kwargs)
            engine.state.output[self.keys["PRED"]] = preds
            engine.fire_event(IterationEvents.FORWARD_COMPLETED)
            if not isinstance(preds, tuple):
                # a single message string: separate arguments would make the
                # exception carry (and print) a tuple instead of one sentence
                raise ValueError(
                    f"Predictions must be tuple in multi-task framework, but got {type(preds)}"
                )
            if not isinstance(targets, tuple):
                raise ValueError(
                    f"Targets must be tuple in multi-task framework, but got {type(targets)}"
                )
            if len(preds) != len(targets):
                raise ValueError(
                    f"Predictions len must equal to targets, but got {len(preds)} != {len(targets)}"
                )

            loss = self.loss_function(preds, targets)

            engine.state.output[self.keys["LOSS"]] = loss
            engine.fire_event(IterationEvents.LOSS_COMPLETED)

        self.network.train()
        # `set_to_none` only work from PyTorch 1.7.0
        if not pytorch_after(1, 7):
            self.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad(set_to_none=self.optim_set_to_none)

        if self.amp and self.scaler is not None:
            # mixed precision: forward under autocast, backward on the scaled loss,
            # optimizer stepped through the scaler
            with torch.cuda.amp.autocast():
                _compute_pred_loss()
            self.scaler.scale(
                engine.state.output[self.keys["LOSS"]]).backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            _compute_pred_loss()
            engine.state.output[self.keys["LOSS"]].backward()
            engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
            self.optimizer.step()
        engine.fire_event(IterationEvents.MODEL_COMPLETED)

        return engine.state.output
Пример #5
0
 def test_compare(self, a, b, p, current, expected=True):
     """Check that `pytorch_after(a, b, p, current)` evaluates to `expected`."""
     outcome = pytorch_after(a, b, p, current)
     self.assertEqual(outcome, expected)