def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor):  # type: ignore
    """
    Check the inputs and delegate to `compute_average_surface_distance`.

    Args:
        y_pred: input data to compute, typical segmentation model output.
            It must be one-hot format and first dim is batch, example shape:
            [16, 3, 32, 32]. The values should be binarized.
        y: ground truth to compute the distance. It must be one-hot format
            and first dim is batch. The values should be binarized.

    Returns:
        Whatever `compute_average_surface_distance` returns for the given
        inputs and this metric's `include_background`, `symmetric` and
        `distance_metric` settings.

    Raises:
        ValueError: when `y_pred` has less than three dimensions.

    Note:
        Both tensors are first passed through `is_binary_tensor`; how that
        helper reacts to non-binarized input is not visible from here.
    """
    # Same check for both tensors, prediction first, then ground truth.
    for tensor, label in ((y_pred, "y_pred"), (y, "y")):
        is_binary_tensor(tensor, label)

    n_dims = y_pred.dim()
    if n_dims < 3:
        raise ValueError("y_pred should have at least three dimensions.")

    # compute (BxC) for each channel for each batch
    distance_kwargs = {
        "y_pred": y_pred,
        "y": y,
        "include_background": self.include_background,
        "symmetric": self.symmetric,
        "distance_metric": self.distance_metric,
    }
    return compute_average_surface_distance(**distance_kwargs)
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor):  # type: ignore
    """
    Check the inputs and delegate to `compute_meaniou`.

    Args:
        y_pred: input data to compute, typical segmentation model output.
            It must be one-hot format and first dim is batch, example shape:
            [16, 3, 32, 32]. The values should be binarized.
        y: ground truth to compute mean IoU metric. It must be one-hot
            format and first dim is batch. The values should be binarized.

    Returns:
        Whatever `compute_meaniou` returns for the given inputs and this
        metric's `include_background` and `ignore_empty` settings.

    Raises:
        ValueError: when `y_pred` has less than three dimensions.

    Note:
        Both tensors are first passed through `is_binary_tensor`; how that
        helper reacts to non-binarized input is not visible from here.
    """
    # Same check for both tensors, prediction first, then ground truth.
    for tensor, label in ((y_pred, "y_pred"), (y, "y")):
        is_binary_tensor(tensor, label)

    dims = y_pred.ndimension()
    if dims < 3:
        raise ValueError(
            f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}."
        )

    # compute IoU (BxC) for each channel for each batch
    keep_background = self.include_background
    skip_empty = self.ignore_empty
    return compute_meaniou(
        y_pred=y_pred,
        y=y,
        include_background=keep_background,
        ignore_empty=skip_empty,
    )
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor):  # type: ignore
    """
    Check the inputs and delegate to `get_confusion_matrix`.

    Args:
        y_pred: input data to compute. It must be one-hot format and first
            dim is batch. The values should be binarized.
        y: ground truth to compute the metric. It must be one-hot format and
            first dim is batch. The values should be binarized.

    Returns:
        Whatever `get_confusion_matrix` returns for the given inputs and
        this metric's `include_background` setting.

    Raises:
        ValueError: when `y_pred` has less than two dimensions.

    Note:
        Both tensors are first passed through `is_binary_tensor`; how that
        helper reacts to non-binarized input is not visible from here.
        As a side effect, `self.compute_sample` is switched off (with a
        warning) when the input looks like a classification output.
    """
    is_binary_tensor(y_pred, "y_pred")
    is_binary_tensor(y, "y")

    # check dimension
    n_dims = y_pred.ndimension()
    if n_dims < 2:
        raise ValueError("y_pred should have at least two dimensions.")

    # 2-D input, or 3-D input whose last axis has length one, is treated as a
    # classification task by the warning below.
    looks_like_classification = n_dims == 2 or (n_dims == 3 and y_pred.shape[-1] == 1)

    # NOTE(review): the reset is assumed to belong under the compute_sample
    # check, i.e. it only fires when the flag was set - confirm against the
    # original layout.
    if looks_like_classification and self.compute_sample:
        warnings.warn("As for classification task, compute_sample should be False.")
        self.compute_sample = False

    return get_confusion_matrix(
        y_pred=y_pred,
        y=y,
        include_background=self.include_background,
    )