def __call__(self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any) -> torch.Tensor:
    """
    Run ``sliding_window_inference`` on ``inputs`` with the settings stored on this object.

    Args:
        inputs: model input data for inference.
        network: target model to execute inference.
            supports callables such as ``lambda x: my_torch_model(x, additional_config)``
        args: optional extra positional arguments; they are appended after the
            stored settings in the ``sliding_window_inference`` call and are
            intended for ``network``.
        kwargs: optional extra keyword arguments; forwarded unchanged to
            ``sliding_window_inference`` and intended for ``network``.

    Returns:
        The value returned by ``sliding_window_inference``.
    """
    # Order is significant: these values are passed positionally, directly
    # after ``inputs`` and before any caller-supplied ``*args``.
    window_settings = (
        self.roi_size,
        self.sw_batch_size,
        network,
        self.overlap,
        self.mode,
        self.sigma_scale,
        self.padding_mode,
        self.cval,
        self.sw_device,
        self.device,
        self.progress,
    )
    return sliding_window_inference(inputs, *window_settings, *args, **kwargs)
def __call__(self, inputs: torch.Tensor, network):
    """
    Unified callable function API of Inferers.

    The input is first interpolated so that its dimensions 2 and 3 equal
    ``self.roi_size[0]`` and ``self.roi_size[1]`` (any further dimensions keep
    their size), then passed through sliding-window inference with ``network``
    wrapped in ``Predict2DFrom3D``, and finally interpolated back to the
    original size of dimensions 2 onwards.

    Args:
        inputs (torch.tensor): model input data for inference.
        network (Network): target model to execute inference.

    Returns:
        The sliding-window output, resampled with ``'nearest'`` interpolation
        to the original size of ``inputs`` from dimension 2 onwards.
    """
    interpolate = torch.nn.functional.interpolate

    # Remember the incoming shape; everything from index 2 onwards is restored at the end.
    source_shape = list(inputs.shape)

    # Same shape, except dimensions 2 and 3 take the first two ROI sizes.
    target_shape = list(source_shape)
    target_shape[2] = self.roi_size[0]
    target_shape[3] = self.roi_size[1]

    # NOTE(review): 'trilinear' implies a 5-D input (batch, channel, 3 more dims) - confirm with callers.
    rescaled_inputs = interpolate(inputs, size=target_shape[2:], mode='trilinear')

    # Wrap the network so it squeezes 3D slices to 2D before predicting.
    slice_predictor = Predict2DFrom3D(network)
    prediction = sliding_window_inference(
        rescaled_inputs,
        self.roi_size,
        self.sw_batch_size,
        slice_predictor,
        self.overlap,
        self.mode,
    )

    # Undo the initial resize so the result lines up with the caller's input.
    return interpolate(prediction, size=source_shape[2:], mode='nearest')
def __call__(self, inputs: torch.Tensor, network):
    """
    Unified callable function API of Inferers.

    Delegates directly to ``sliding_window_inference`` using the ROI size,
    sliding-window batch size, overlap and mode stored on this object.

    Args:
        inputs (torch.tensor): model input data for inference.
        network (Network): target model to execute inference.

    Returns:
        The value returned by ``sliding_window_inference``.
    """
    # Read the stored configuration in the order the call consumes it.
    roi_size = self.roi_size
    sw_batch_size = self.sw_batch_size
    overlap = self.overlap
    mode = self.mode

    result = sliding_window_inference(inputs, roi_size, sw_batch_size, network, overlap, mode)
    return result
def __call__(self, inputs: torch.Tensor, network):
    """
    Unified callable function API of Inferers.

    ``network`` is wrapped in ``Predict2DFrom3D`` and the wrapper is used as
    the predictor for ``sliding_window_inference``.

    Args:
        inputs (torch.tensor): model input data for inference.
        network (Network): target model to execute inference.

    Returns:
        The value returned by ``sliding_window_inference``.
    """
    # The wrapper squeezes 3D slices to 2D before running the network prediction.
    slice_predictor = Predict2DFrom3D(network)

    # Positional arguments that follow ``inputs`` in the call, in order.
    window_args = (
        self.roi_size,
        self.sw_batch_size,
        slice_predictor,
        self.overlap,
        self.mode,
    )
    return sliding_window_inference(inputs, *window_args)
def __call__(self, inputs: torch.Tensor, network: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    """
    Run ``sliding_window_inference`` on ``inputs`` with the settings stored on this object.

    Args:
        inputs: model input data for inference.
        network: target model to execute inference.
            supports callables such as ``lambda x: my_torch_model(x, additional_config)``

    Returns:
        The value returned by ``sliding_window_inference``.
    """
    # Every argument is passed by keyword, so collect them under the exact
    # parameter names and expand the mapping in a single call.
    call_kwargs = {
        "inputs": inputs,
        "roi_size": self.roi_size,
        "sw_batch_size": self.sw_batch_size,
        "predictor": network,
        "overlap": self.overlap,
        "mode": self.mode,
        "sigma_scale": self.sigma_scale,
        "padding_mode": self.padding_mode,
        "cval": self.cval,
    }
    return sliding_window_inference(**call_kwargs)
def __call__(self, inputs: torch.Tensor, network: torch.nn.Module) -> torch.Tensor:
    """
    Unified callable function API of Inferers.

    Forwards ``inputs`` and ``network`` (as ``predictor``) to
    ``sliding_window_inference`` together with the settings stored on this
    object, all passed by keyword.

    Args:
        inputs: model input data for inference.
        network: target model to execute inference.

    Returns:
        The value returned by ``sliding_window_inference``.
    """
    # Each stored setting is forwarded under a keyword identical to its
    # attribute name, so gather them by name instead of spelling each one out.
    forwarded_names = (
        "roi_size",
        "sw_batch_size",
        "overlap",
        "mode",
        "sigma_scale",
        "padding_mode",
        "cval",
    )
    stored_settings = {name: getattr(self, name) for name in forwarded_names}
    return sliding_window_inference(inputs=inputs, predictor=network, **stored_settings)