Esempio n. 1
0
    def __call__(self, inputs: torch.Tensor,
                 network: Callable[..., torch.Tensor],
                 *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Run ``sliding_window_inference`` on ``inputs`` with this instance's settings.

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: optional args forwarded to ``sliding_window_inference``
                (intended to be passed on to ``network``).
            kwargs: optional keyword args forwarded to ``sliding_window_inference``
                (intended to be passed on to ``network``).

        Returns:
            The value returned by ``sliding_window_inference``.

        """
        # The helper is called positionally, so the order of this tuple is
        # significant: it must stay in the order the helper's parameters expect.
        positional = (
            inputs,
            self.roi_size,
            self.sw_batch_size,
            network,
            self.overlap,
            self.mode,
            self.sigma_scale,
            self.padding_mode,
            self.cval,
            self.sw_device,
            self.device,
            self.progress,
        )
        return sliding_window_inference(*positional, *args, **kwargs)
Esempio n. 2
0
    def __call__(self, inputs: torch.Tensor, network):
        """
        Resize the input to the ROI size, predict with a 2D-from-3D wrapper
        through ``sliding_window_inference``, then resize the prediction back
        to the input's original spatial size.

        Args:
            inputs (torch.tensor): model input data for inference.
            network (Network): target model to execute inference.

        Returns:
            The prediction, interpolated back to ``inputs.shape[2:]``.

        """
        full_shape = list(inputs.shape)

        # Start from a copy of the input shape and override only the
        # dimensions at index 2 and 3 with the first two ROI sizes; any
        # further dimensions keep their original extent.
        target_shape = list(full_shape)
        target_shape[2] = self.roi_size[0]
        target_shape[3] = self.roi_size[1]

        # NOTE(review): 'trilinear' implies a 5-D (batch, channel, 3 spatial
        # dims) input -- confirm with callers.
        shrunk = torch.nn.functional.interpolate(
            inputs, size=target_shape[2:], mode='trilinear')

        # Wrap the network so that it is applied through Predict2DFrom3D
        # (squeezes 3D slices to 2D before performing the network prediction).
        slice_predictor = Predict2DFrom3D(network)
        prediction = sliding_window_inference(
            shrunk,
            self.roi_size,
            self.sw_batch_size,
            slice_predictor,
            self.overlap,
            self.mode,
        )

        # Undo the initial resize so the result matches the input's spatial size.
        return torch.nn.functional.interpolate(
            prediction, size=full_shape[2:], mode='nearest')
Esempio n. 3
0
    def __call__(self, inputs: torch.Tensor, network):
        """
        Unified callable function API of Inferers.

        Forwards ``inputs`` and ``network`` to ``sliding_window_inference``
        together with the window settings stored on this instance.

        Args:
            inputs (torch.tensor): model input data for inference.
            network (Network): target model to execute inference.

        Returns:
            The value returned by ``sliding_window_inference``.

        """
        # Positional call: the tuple order mirrors the helper's parameter order.
        window_args = (
            self.roi_size,
            self.sw_batch_size,
            network,
            self.overlap,
            self.mode,
        )
        return sliding_window_inference(inputs, *window_args)
Esempio n. 4
0
    def __call__(self, inputs: torch.Tensor, network):
        """
        Unified callable function API of Inferers.

        Wraps ``network`` in ``Predict2DFrom3D`` and hands the wrapper to
        ``sliding_window_inference`` as the predictor.

        Args:
            inputs (torch.tensor): model input data for inference.
            network (Network): target model to execute inference.

        Returns:
            The value returned by ``sliding_window_inference``.

        """
        # The wrapper squeezes 3D slices to 2D before performing the network
        # prediction, so the sliding window can be run with a 2D network.
        wrapped = Predict2DFrom3D(network)

        # Positional call: the list order mirrors the helper's parameter order.
        call_args = [
            inputs,
            self.roi_size,
            self.sw_batch_size,
            wrapped,
            self.overlap,
            self.mode,
        ]
        return sliding_window_inference(*call_args)
Esempio n. 5
0
    def __call__(self, inputs: torch.Tensor, network: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
        """
        Run ``sliding_window_inference`` on ``inputs`` with this instance's settings.

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``

        Returns:
            The value returned by ``sliding_window_inference``.

        """
        # Every argument is passed by keyword; the dict keys are the helper's
        # parameter names.
        call_kwargs = {
            "inputs": inputs,
            "roi_size": self.roi_size,
            "sw_batch_size": self.sw_batch_size,
            "predictor": network,
            "overlap": self.overlap,
            "mode": self.mode,
            "sigma_scale": self.sigma_scale,
            "padding_mode": self.padding_mode,
            "cval": self.cval,
        }
        return sliding_window_inference(**call_kwargs)
Esempio n. 6
0
    def __call__(self, inputs: torch.Tensor,
                 network: torch.nn.Module) -> torch.Tensor:
        """
        Unified callable function API of Inferers.

        Passes ``inputs`` and ``network`` (as ``predictor``) to
        ``sliding_window_inference`` by keyword, together with the settings
        stored on this instance.

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.

        Returns:
            The value returned by ``sliding_window_inference``.

        """
        # Instance attributes whose names match the helper's keyword names.
        forwarded = (
            "roi_size",
            "sw_batch_size",
            "overlap",
            "mode",
            "sigma_scale",
            "padding_mode",
            "cval",
        )
        options = {name: getattr(self, name) for name in forwarded}
        return sliding_window_inference(inputs=inputs, predictor=network, **options)