import torch
import torch.nn.functional as F

from monai.utils import pytorch_after


def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tensor:
    """
    Filter `x` with `kernel`, independently for each batch and channel.

    Args:
        x: the input image, must have shape (batch, channels, H[, W, D]).
        kernel: `kernel` must at least have the spatial shape (H_k[, W_k, D_k]).
            `kernel` shape must be broadcastable to the `batch` and `channels` dimensions of `x`.
        kwargs: keyword arguments passed to `conv*d()` functions.

    Returns:
        The filtered `x`.

    Examples:

    .. code-block:: python

        >>> import torch
        >>> from monai.networks.layers import apply_filter
        >>> img = torch.rand(2, 5, 10, 10)  # batch_size 2, channels 5, 10x10 2D images
        >>> out = apply_filter(img, torch.rand(3, 3))   # spatial kernel
        >>> out = apply_filter(img, torch.rand(5, 3, 3))  # channel-wise kernels
        >>> out = apply_filter(img, torch.rand(2, 5, 3, 3))  # batch-, channel-wise kernels

    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")
    batch, chns, *spatials = x.shape
    n_spatial = len(spatials)
    if n_spatial > 3:
        raise NotImplementedError(f"Only spatial dimensions up to 3 are supported but got {n_spatial}.")
    k_size = len(kernel.shape)
    if k_size < n_spatial or k_size > n_spatial + 2:
        raise ValueError(
            f"kernel must have {n_spatial} ~ {n_spatial + 2} dimensions to match the input shape {x.shape}."
        )
    kernel = kernel.to(x)
    # broadcast kernel size to (batch, chns, spatial_kernel_size)
    kernel = kernel.expand(batch, chns, *kernel.shape[(k_size - n_spatial):])
    kernel = kernel.reshape(-1, 1, *kernel.shape[2:])  # flatten to (batch * chns, 1, ...) for groups=1 per group
    x = x.view(1, kernel.shape[0], *spatials)
    conv = [F.conv1d, F.conv2d, F.conv3d][n_spatial - 1]
    if "padding" not in kwargs:
        if pytorch_after(1, 10):
            kwargs["padding"] = "same"
        else:
            # even-sized kernels are not supported
            kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
    if "stride" not in kwargs:
        kwargs["stride"] = 1
    output = conv(x, kernel, groups=kernel.shape[0], bias=None, **kwargs)
    return output.view(batch, chns, *output.shape[2:])
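
# A minimal smoke test for `apply_filter`, distilled from the docstring examples
# above; the box-blur kernel and the shape assertions are illustrative, not part
# of the library.
import torch
from monai.networks.layers import apply_filter

img = torch.rand(2, 5, 10, 10)      # (batch=2, channels=5, 10x10 images)
box = torch.full((3, 3), 1.0 / 9)   # 3x3 box-blur kernel

out = apply_filter(img, box)        # one spatial kernel, broadcast to all batches/channels
assert out.shape == img.shape       # "same"-style padding preserves the spatial size

out = apply_filter(img, torch.rand(5, 3, 3))     # one kernel per channel
assert out.shape == img.shape
out = apply_filter(img, torch.rand(2, 5, 3, 3))  # one kernel per batch and channel
assert out.shape == img.shape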
def __init__(
    self,
    nn_module,
    target_layer_names: Union[str, Sequence[str]],
    register_forward: bool = False,
    register_backward: bool = False,
):
    """
    Args:
        nn_module: the model to be wrapped.
        target_layer_names: the names of the layers to cache.
        register_forward: whether to cache the forward pass output corresponding to `target_layer_names`.
        register_backward: whether to cache the backward pass output corresponding to `target_layer_names`.
    """
    self.model = nn_module
    self.target_layers = ensure_tuple(target_layer_names)

    self.gradients: Dict[str, torch.Tensor] = {}
    self.activations: Dict[str, torch.Tensor] = {}
    self.score = None
    self.class_idx = None
    self.register_backward = register_backward
    self.register_forward = register_forward

    _registered = []
    for name, mod in nn_module.named_modules():
        if name not in self.target_layers:
            continue
        _registered.append(name)
        if self.register_backward:
            if pytorch_after(1, 8):
                if "inplace" in mod.__dict__ and mod.__dict__["inplace"]:
                    # inplace=True causes errors for register_full_backward_hook
                    mod.__dict__["inplace"] = False
                mod.register_full_backward_hook(self.backward_hook(name))
            else:
                mod.register_backward_hook(self.backward_hook(name))
        if self.register_forward:
            mod.register_forward_hook(self.forward_hook(name))
    if self.target_layers and len(_registered) != len(self.target_layers):
        warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.")
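
# For context, a self-contained sketch of the forward-hook caching pattern this
# constructor sets up, on a toy `nn.Sequential`; the layer name "0" and the
# module-level `activations` dict are illustrative stand-ins for the wrapper's
# internals, not part of its API.
import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Conv2d(4, 8, 3))
activations = {}

def forward_hook(name):
    # return a closure keyed by layer name, mirroring `self.forward_hook(name)` above
    def _hook(module, inputs, output):
        activations[name] = output.detach()  # cache this layer's output
    return _hook

# register on the named target layer, as the loop over named_modules() does
dict(net.named_modules())["0"].register_forward_hook(forward_hook("0"))

_ = net(torch.rand(1, 1, 8, 8))
print(activations["0"].shape)  # torch.Size([1, 4, 6, 6])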
def save_func(engine):
    if pytorch_after(1, 9, 1):
        # engine output is a list of items: save every item individually
        for m in engine.state.output:
            saver(m)
    else:
        # legacy path: save only the first element of the output
        saver(engine.state.output[0])
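
# A standalone sketch of the version gate above, with a stub `saver` and a fake
# engine state; `record` and the SimpleNamespace objects are purely illustrative.
from types import SimpleNamespace

from monai.utils import pytorch_after

record = []
saver = record.append  # stand-in for the real saver callable

engine = SimpleNamespace(state=SimpleNamespace(output=[{"pred": 0}, {"pred": 1}]))

if pytorch_after(1, 9, 1):
    for m in engine.state.output:  # save every item in the output list
        saver(m)
else:
    saver(engine.state.output[0])  # save the first element only

print(len(record))  # 2 on torch >= 1.9.1, otherwise 1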
def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
    if batchdata is None:
        raise ValueError("Must provide batch data for current iteration.")
    batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
    if len(batch) == 2:
        inputs, targets = batch
        args: Tuple = ()
        kwargs: Dict = {}
    else:
        inputs, targets, args, kwargs = batch
    # put iteration outputs into engine.state
    engine.state.output = {self.keys["IMAGE"]: inputs, self.keys["LABEL"]: targets}

    def _compute_pred_loss():
        preds = self.inferer(inputs, self.network, *args, **kwargs)
        engine.state.output[self.keys["PRED"]] = preds
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        if not isinstance(preds, tuple):
            raise ValueError(f"Predictions must be a tuple in the multi-task framework, but got {type(preds)}.")
        if not isinstance(targets, tuple):
            raise ValueError(f"Targets must be a tuple in the multi-task framework, but got {type(targets)}.")
        if len(preds) != len(targets):
            raise ValueError(
                f"Number of predictions must equal the number of targets, but got {len(preds)} != {len(targets)}."
            )
        loss = self.loss_function(preds, targets)
        engine.state.output[self.keys["LOSS"]] = loss
        engine.fire_event(IterationEvents.LOSS_COMPLETED)

    self.network.train()
    # `set_to_none` is only supported from PyTorch 1.7.0
    if not pytorch_after(1, 7):
        self.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad(set_to_none=self.optim_set_to_none)

    if self.amp and self.scaler is not None:
        with torch.cuda.amp.autocast():
            _compute_pred_loss()
        self.scaler.scale(engine.state.output[self.keys["LOSS"]]).backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
        self.scaler.step(self.optimizer)
        self.scaler.update()
    else:
        _compute_pred_loss()
        engine.state.output[self.keys["LOSS"]].backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
        self.optimizer.step()
    engine.fire_event(IterationEvents.MODEL_COMPLETED)
    return engine.state.output
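
# The `_compute_pred_loss` contract above expects `preds` and `targets` to be
# equal-length tuples, one entry per task. A hedged sketch of a compatible loss
# module follows; `MultiTaskLoss` is a hypothetical name and the equal task
# weighting is an assumption, not a MONAI API.
import torch
import torch.nn as nn

class MultiTaskLoss(nn.Module):
    """Sum of per-task losses; `preds` and `targets` are equal-length tuples."""

    def __init__(self, criteria):
        super().__init__()
        self.criteria = nn.ModuleList(criteria)  # one criterion per task

    def forward(self, preds, targets):
        # equal weighting here; per-task weights would slot into this sum
        return sum(fn(p, t) for fn, p, t in zip(self.criteria, preds, targets))

loss_fn = MultiTaskLoss([nn.CrossEntropyLoss(), nn.MSELoss()])
preds = (torch.randn(4, 3), torch.randn(4, 1))
targets = (torch.randint(0, 3, (4,)), torch.randn(4, 1))
print(loss_fn(preds, targets))  # scalar tensor combining both tasks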
def test_compare(self, a, b, p, current, expected=True):
    """Test `pytorch_after(a, b, p)` against the version string `current`."""
    self.assertEqual(pytorch_after(a, b, p, current), expected)
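
# For reference, a sketch of how this case might be driven with the
# `parameterized` package, as is common in MONAI's test suite; the class name
# and version tuples below are illustrative, assuming `pytorch_after` returns
# True when the current version is at least the queried one.
import unittest

from parameterized import parameterized

from monai.utils import pytorch_after

class TestPytorchVersionAfter(unittest.TestCase):
    @parameterized.expand(
        [
            (1, 5, 9, "1.6.0"),         # current is newer -> defaults to expected=True
            (1, 6, 0, "1.6.0"),         # equal versions count as "after"
            (1, 7, 0, "1.6.0", False),  # current is older -> False
        ]
    )
    def test_compare(self, a, b, p, current, expected=True):
        """Test `pytorch_after(a, b, p)` against the version string `current`."""
        self.assertEqual(pytorch_after(a, b, p, current), expected)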