def compute_meandice(
    y_pred,
    y,
    include_background=False,
    to_onehot_y=True,
    mutually_exclusive=True,
    add_sigmoid=False,
    logit_thresh=None,
):
    """Computes dice score metric from full size Tensor and collects average.

    Args:
        y_pred (torch.Tensor): input data to compute, typical segmentation model output.
            it must be One-Hot format and first dim is batch, example shape: [16, 3, 32, 32].
        y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch.
        include_background (Bool): whether to include the first channel of the predicted
            output in the Dice computation. When False, the first channel of both `y` and
            `y_pred` is dropped (only if the tensor has more than one channel).
        to_onehot_y (Bool): whether to convert `y` into the one-hot format.
        mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using
            a combination of argmax and to_onehot.
        add_sigmoid (Bool): whether to add sigmoid function to y_pred before computation.
            Only used when `mutually_exclusive` is False.
        logit_thresh (Float): the threshold value used to convert `y_pred` into a binary matrix.
            Only used when `mutually_exclusive` is False; when None, `y_pred` is used as-is.

    Returns:
        A scalar tensor: the mean of the per-batch, per-channel Dice scores. A channel that is
        empty in both `y` and `y_pred` gives 0 / 0 = nan, which propagates into the mean.

    Raises:
        ValueError: when `logit_thresh` is given together with `mutually_exclusive=True`, or
            when `y` and `y_pred` do not have the same shape after preprocessing.

    Note:
        This method provides two options to convert `y_pred` into a binary matrix:
        (1) when `mutually_exclusive` is True, it uses a combination of argmax and to_onehot;
        (2) when `mutually_exclusive` is False, it uses a threshold `logit_thresh`
        (optionally with a sigmoid function before thresholding).
    """
    n_channels_y_pred = y_pred.shape[1]
    if mutually_exclusive:
        if logit_thresh is not None:
            raise ValueError('`logit_thresh` is incompatible when mutually_exclusive is True.')
        y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
        y_pred = one_hot(y_pred, n_channels_y_pred)
    else:  # channel-wise thresholding
        if add_sigmoid:
            y_pred = torch.sigmoid(y_pred)
        if logit_thresh is not None:
            y_pred = (y_pred >= logit_thresh).float()

    if to_onehot_y:
        y = one_hot(y, n_channels_y_pred)

    if not include_background:
        y = y[:, 1:] if y.shape[1] > 1 else y
        y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

    # Mismatched shapes would otherwise broadcast silently in `y * y_pred` below and
    # produce a meaningless score instead of an error.
    if y.shape != y_pred.shape:
        raise ValueError(
            f"ground truth has differing shape ({y.shape}) from y_pred ({y_pred.shape})."
        )

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, y_pred.dim()))
    intersection = torch.sum(y * y_pred, reduce_axis)
    y_o = torch.sum(y, reduce_axis)
    y_pred_o = torch.sum(y_pred, reduce_axis)
    denominator = y_o + y_pred_o
    f = (2.0 * intersection) / denominator
    # final reduce_mean across batches and channels
    return torch.mean(f)
def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5):
    """Compute the generalized Dice loss between ``input`` and ``target``.

    Args:
        input (tensor): the shape should be BNH[WD].
        target (tensor): the shape should be BNH[WD].
        smooth: a small constant to avoid nan.

    Returns:
        The loss reduced according to ``self.reduction``: a scalar for "mean" / "sum", or a
        tensor of shape [B] for "none". The shape is [B] (not [B, N]) because the class axis
        is summed inside the generalized Dice ratio.

    Raises:
        AssertionError: if ``target`` and ``input`` shapes differ after preprocessing.
        ValueError: if ``self.reduction`` is not one of "mean", "sum", "none".
    """
    if self.sigmoid:
        input = torch.sigmoid(input)
    n_pred_ch = input.shape[1]
    if n_pred_ch == 1:
        if self.softmax:
            warnings.warn("single channel prediction, `softmax=True` ignored.")
        if self.to_onehot_y:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        if not self.include_background:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
    else:
        if self.softmax:
            input = torch.softmax(input, 1)
        if self.to_onehot_y:
            target = one_hot(target, n_pred_ch)
        if not self.include_background:
            # if skipping background, removing first channel
            target = target[:, 1:]
            input = input[:, 1:]
    assert (
        target.shape == input.shape
    ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"

    # reducing only spatial dimensions (not batch nor channels) -> tensors of shape [B, N]
    reduce_axis = list(range(2, len(input.shape)))
    intersection = torch.sum(target * input, reduce_axis)
    ground_o = torch.sum(target, reduce_axis)
    pred_o = torch.sum(input, reduce_axis)
    denominator = ground_o + pred_o

    # Per-class weights. Depending on ``w_func`` (e.g. a reciprocal), classes absent from the
    # ground truth may get an infinite weight; within each batch row those are replaced by
    # the row maximum computed after the infinities have been zeroed.
    w = self.w_func(ground_o.float())
    for b in w:
        infs = torch.isinf(b)
        b[infs] = 0.0
        b[infs] = torch.max(b)

    # weighted sums over the class axis -> one loss value per batch item, shape [B]
    f = 1.0 - (2.0 * (intersection * w).sum(1) + smooth) / ((denominator * w).sum(1) + smooth)

    if self.reduction == "mean":
        f = torch.mean(f)  # average over the batch
    elif self.reduction == "sum":
        f = torch.sum(f)  # sum over the batch
    elif self.reduction == "none":
        pass  # returns [B] losses
    else:
        raise ValueError(f"reduction={self.reduction} is invalid.")
    return f
def forward(self, input: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5):
    """Tversky loss ``1 - TP / (TP + alpha * FP + beta * FN)`` per batch item and channel.

    Args:
        input (tensor): the shape should be BNH[WD].
        target (tensor): the shape should be BNH[WD].
        smooth (float): a small constant to avoid nan.

    Returns:
        A scalar for reduction "mean" / "sum", or the unreduced [B, N] scores for "none".

    Raises:
        AssertionError: if ``target`` and ``input`` shapes differ after preprocessing.
        ValueError: if ``self.reduction`` is not "mean", "sum" or "none".
    """
    if self.do_sigmoid:
        input = torch.sigmoid(input)

    n_pred_ch = input.shape[1]
    if n_pred_ch != 1:
        if self.do_softmax:
            input = torch.softmax(input, 1)
        if self.to_onehot_y:
            target = one_hot(target, n_pred_ch)
        if not self.include_background:
            # drop the background (first) channel from both tensors
            target, input = target[:, 1:], input[:, 1:]
    else:
        # these options need more than one channel, so they are skipped with a warning
        if self.do_softmax:
            warnings.warn("single channel prediction, `do_softmax=True` ignored.")
        if self.to_onehot_y:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        if not self.include_background:
            warnings.warn("single channel prediction, `include_background=False` ignored.")

    assert (
        target.shape == input.shape
    ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"

    # sum over the spatial axes only, keeping [B, N]
    spatial = list(range(2, len(input.shape)))
    true_pos = torch.sum(input * target, spatial)
    false_pos = self.alpha * torch.sum(input * (1 - target), spatial)
    false_neg = self.beta * torch.sum((1 - input) * target, spatial)
    score = 1.0 - (true_pos + smooth) / (true_pos + false_pos + false_neg + smooth)

    if self.reduction == "mean":
        return score.mean()
    if self.reduction == "sum":
        return score.sum()  # sum over the batch and channel dims
    if self.reduction == "none":
        return score  # per-item, per-class losses
    raise ValueError(f"reduction={self.reduction} is invalid.")
def test_shape(self, input_data, expected_shape, expected_result=None):
    """Check ``one_hot(**input_data)`` for shape, optional values, and dtype.

    The expected dtype is ``input_data["dtype"]`` when that key is present, otherwise
    ``torch.float``.
    """
    encoded = one_hot(**input_data)
    self.assertEqual(encoded.shape, expected_shape)
    if expected_result is not None:
        self.assertTrue(np.allclose(expected_result, encoded.numpy()))
    wanted_dtype = input_data["dtype"] if "dtype" in input_data else torch.float
    self.assertEqual(encoded.dtype, wanted_dtype)
def forward(self, input, target, smooth=1e-5):
    """Dice loss, or Jaccard loss when ``self.jaccard`` is set.

    Args:
        input (tensor): the shape should be BNH[WD].
        target (tensor): the shape should be BNH[WD].
        smooth (float): a small constant to avoid nan.

    Returns:
        A scalar for reduction "mean" / "sum", or the unreduced [B, N] losses for "none".

    Raises:
        AssertionError: if ``target`` and ``input`` shapes differ after preprocessing.
        ValueError: if ``self.reduction`` is not "mean", "sum" or "none".
    """
    if self.do_sigmoid:
        input = torch.sigmoid(input)
    n_pred_ch = input.shape[1]
    if n_pred_ch == 1:
        if self.do_softmax:
            warnings.warn("single channel prediction, `do_softmax=True` ignored.")
        if self.to_onehot_y:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        if not self.include_background:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
    else:
        if self.do_softmax:
            input = torch.softmax(input, 1)
        if self.to_onehot_y:
            target = one_hot(target, n_pred_ch)
        if not self.include_background:
            # if skipping background, removing first channel
            target = target[:, 1:]
            input = input[:, 1:]
    assert (
        target.shape == input.shape
    ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, len(input.shape)))
    intersection = torch.sum(target * input, reduce_axis)

    if self.squared_pred:
        target = torch.pow(target, 2)
        input = torch.pow(input, 2)

    ground_o = torch.sum(target, reduce_axis)
    pred_o = torch.sum(input, reduce_axis)
    denominator = ground_o + pred_o

    if self.jaccard:
        # Jaccard = I / (|A| + |B| - I). The numerator below is 2 * I, so the denominator is
        # doubled as well; otherwise a perfect prediction scores 2 and the loss becomes -1.
        denominator = 2.0 * (denominator - intersection)

    f = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)

    if self.reduction == "sum":
        return f.sum()  # sum over the batch and channel dims
    if self.reduction == "none":
        return f  # returns [N, n_classes] losses
    if self.reduction == "mean":
        return f.mean()  # the batch and channel average
    raise ValueError(f"reduction={self.reduction} is invalid.")
def forward(self, pred, ground, smooth=1e-5):
    """Dice loss, or Jaccard loss when ``self.jaccard`` is set, averaged over batch and channels.

    Args:
        pred (tensor): the shape should be BNH[WD].
        ground (tensor): the shape should be BNH[WD].
        smooth (float): a small constant to avoid nan.

    Returns:
        Scalar tensor: ``1 - mean`` of the per-item, per-channel overlap score.

    Raises:
        AssertionError: if ``ground`` and ``pred`` shapes differ after preprocessing.
    """
    if self.do_sigmoid:
        pred = torch.sigmoid(pred)
    n_pred_ch = pred.shape[1]
    if n_pred_ch == 1:
        if self.do_softmax:
            warnings.warn("single channel prediction, `do_softmax=True` ignored.")
        if self.to_onehot_y:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
        if not self.include_background:
            warnings.warn("single channel prediction, `include_background=False` ignored.")
    else:
        if self.do_softmax:
            pred = torch.softmax(pred, 1)
        if self.to_onehot_y:
            ground = one_hot(ground, n_pred_ch)
        if not self.include_background:
            # if skipping background, removing first channel
            ground = ground[:, 1:]
            pred = pred[:, 1:]
    assert ground.shape == pred.shape, "ground truth one-hot has differing shape (%r) from pred (%r)" % (
        ground.shape,
        pred.shape,
    )

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, len(pred.shape)))
    intersection = torch.sum(ground * pred, reduce_axis)

    if self.squared_pred:
        ground = torch.pow(ground, 2)
        pred = torch.pow(pred, 2)

    ground_o = torch.sum(ground, reduce_axis)
    pred_o = torch.sum(pred, reduce_axis)
    denominator = ground_o + pred_o

    if self.jaccard:
        # Jaccard = I / (|A| + |B| - I). The numerator below is 2 * I, so the denominator is
        # doubled as well; otherwise a perfect prediction scores 2 and the loss becomes -1.
        denominator = 2.0 * (denominator - intersection)

    f = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - f.mean()  # final reduce_mean across batches and channels
def forward(self, pred, ground, smooth=1e-5):
    """Generalized Dice loss (class-weighted Dice, weights from ``self.w_func``).

    Args:
        pred (tensor): the shape should be BNH[WD].
        ground (tensor): the shape should be B1H[WD].
        smooth (float): a small constant to avoid nan.

    Returns:
        Scalar tensor: ``1 - mean over the batch`` of the generalized Dice score.

    Raises:
        ValueError: if ``ground`` has more than one channel, or ``do_softmax`` is requested
            for a single-channel prediction.
        AssertionError: if the ground truth and prediction shapes differ after preprocessing.
    """
    if ground.shape[1] != 1:
        raise ValueError("Ground truth should have only a single channel, shape is " + str(ground.shape))

    psum = pred.float()
    if self.do_sigmoid:
        psum = psum.sigmoid()  # use sigmoid activation
    if pred.shape[1] == 1:
        if self.do_softmax:
            raise ValueError('do_softmax is not compatible with single channel prediction.')
        if not self.include_background:
            warnings.warn('single channel prediction, `include_background=False` ignored.')
        tsum = ground
    else:  # multiclass dice loss
        if self.do_softmax:
            psum = torch.softmax(pred, 1)
        tsum = one_hot(ground, pred.shape[1])  # B1HW(D) -> BNHW(D)
        # exclude background category so that it doesn't overwhelm the other segmentations if they are small
        if not self.include_background:
            tsum = tsum[:, 1:]
            psum = psum[:, 1:]
    assert tsum.shape == psum.shape, (
        "Ground truth one-hot has differing shape (%r) from source (%r)" % (tsum.shape, psum.shape))

    batchsize, n_classes = tsum.shape[:2]
    # flatten spatial dims -> [B, N, S]; reshape also accepts the non-contiguous background slice
    tsum = tsum.float().reshape(batchsize, n_classes, -1)
    psum = psum.reshape(batchsize, n_classes, -1)

    intersection = (psum * tsum).sum(2)  # [B, N]
    sums = (psum + tsum).sum(2)  # [B, N]

    # Per-class weights [B, N]. Infinite weights (possible for classes absent from the ground
    # truth, depending on w_func) are replaced by the row maximum after zeroing them.
    w = self.w_func(tsum.sum(2))
    for b in w:
        infs = torch.isinf(b)
        b[infs] = 0.0
        b[infs] = torch.max(b)

    # Generalized Dice: weight each class and sum over the class axis *before* dividing.
    # Dividing per class would cancel the weights and reduce this to plain Dice.
    score = (2.0 * (intersection * w).sum(1) + smooth) / ((sums * w).sum(1) + smooth)
    return 1 - score.mean()
def forward(self, pred, ground, smooth=1e-5):
    """Tversky loss averaged over batch and channels.

    Args:
        pred (tensor): the shape should be BNH[WD].
        ground (tensor): the shape should be BNH[WD].
        smooth (float): a small constant to avoid nan.

    Returns:
        Scalar tensor ``1 - mean(TP / (TP + alpha * FP + beta * FN))`` (smoothed).

    Raises:
        AssertionError: if ``ground`` and ``pred`` shapes differ after preprocessing.
    """
    if self.do_sigmoid:
        pred = torch.sigmoid(pred)

    n_pred_ch = pred.shape[1]
    if n_pred_ch != 1:
        if self.do_softmax:
            pred = torch.softmax(pred, 1)
        if self.to_onehot_y:
            ground = one_hot(ground, n_pred_ch)
        if not self.include_background:
            # drop the background (first) channel
            ground, pred = ground[:, 1:], pred[:, 1:]
    else:
        if self.do_softmax:
            warnings.warn('single channel prediction, `do_softmax=True` ignored.')
        if self.to_onehot_y:
            warnings.warn('single channel prediction, `to_onehot_y=True` ignored.')
        if not self.include_background:
            warnings.warn('single channel prediction, `include_background=False` ignored.')

    assert ground.shape == pred.shape, (
        'ground truth one-hot has differing shape (%r) from pred (%r)' % (ground.shape, pred.shape))

    # sum over spatial axes only -> [B, N]
    spatial = list(range(2, len(pred.shape)))
    true_pos = torch.sum(pred * ground, spatial)
    false_pos = self.alpha * torch.sum(pred * (1 - ground), spatial)
    false_neg = self.beta * torch.sum((1 - pred) * ground, spatial)

    tversky = (true_pos + smooth) / (true_pos + false_pos + false_neg + smooth)
    return 1.0 - tversky.mean()
def __call__(self, img, to_onehot=None, num_classes=None):
    """Split ``img`` along dim 1 into a list of single-channel slices.

    Args:
        img: tensor to split; dim 1 is treated as the channel axis.
        to_onehot: when truthy, or when ``self.to_onehot`` is truthy, ``img`` is first
            converted with ``one_hot``.
        num_classes: class count for the one-hot conversion; ``self.num_classes`` is used
            when None.

    Returns:
        A list holding ``img[:, i:i + 1]`` for every channel ``i`` (dim 1 kept with size 1).
    """
    if to_onehot or self.to_onehot:
        n_cls = self.num_classes if num_classes is None else num_classes
        assert isinstance(n_cls, int), "must specify class number for One-Hot."
        img = one_hot(img, n_cls)
    return [img[:, ch:ch + 1] for ch in range(img.shape[1])]
def __call__(  # type: ignore # see issue #495
    self, img, to_onehot: Optional[bool] = None, num_classes: Optional[int] = None
):
    """Split ``img`` along dim 1 into single-channel slices, optionally one-hot encoding first.

    Args:
        img: tensor to split; dim 1 is treated as the channel axis.
        to_onehot: when truthy, or when ``self.to_onehot`` is truthy, ``img`` is first
            converted with ``one_hot``.
        num_classes: class count for the conversion; falls back to ``self.num_classes``.

    Returns:
        A list with one ``img[:, i:i + 1]`` slice per channel.
    """
    if to_onehot or self.to_onehot:
        class_count = self.num_classes if num_classes is None else num_classes
        assert isinstance(class_count, int), "must specify class number for One-Hot."
        img = one_hot(img, class_count)
    slices = []
    for channel in range(img.shape[1]):
        slices.append(img[:, channel:channel + 1])
    return slices
def forward(self, pred, ground, smooth=1e-5):
    """Batch-wise Dice loss: all channels of a batch item are flattened into one vector.

    Args:
        pred (tensor): the shape should be BNH[WD].
        ground (tensor): the shape should be B1H[WD].
        smooth (float): a small constant to avoid nan.

    Returns:
        Scalar tensor: ``1 - mean over the batch`` of the per-item Dice score.

    Raises:
        ValueError: if ``ground`` has more than one channel, ``do_softmax`` is requested for a
            single-channel prediction, or ``do_sigmoid`` and ``do_softmax`` are both set for a
            multi-channel prediction.
        AssertionError: if the ground truth and prediction shapes differ after preprocessing.
    """
    import warnings  # local import: this block must not depend on an unseen module-level import

    if ground.shape[1] != 1:
        raise ValueError("Ground truth should have only a single channel, shape is " + str(ground.shape))

    psum = pred.float()
    if self.do_sigmoid:
        psum = psum.sigmoid()  # use sigmoid activation
    if pred.shape[1] == 1:
        if self.do_softmax:
            raise ValueError('do_softmax is not compatible with single channel prediction.')
        if not self.include_background:
            # the option is ignored here, so warn rather than abort the forward pass
            warnings.warn('single channel prediction, `include_background=False` ignored.')
        tsum = ground
    else:  # multiclass dice loss
        if self.do_softmax:
            if self.do_sigmoid:
                raise ValueError('do_sigmoid=True and do_softmax=True are not compatible.')
            psum = torch.softmax(pred, 1)
        tsum = one_hot(ground, pred.shape[1])  # B1HW(D) -> BNHW(D)
        # exclude background category so that it doesn't overwhelm the other segmentations if they are small
        if not self.include_background:
            tsum = tsum[:, 1:]
            psum = psum[:, 1:]
    assert tsum.shape == psum.shape, (
        "Ground truth one-hot has differing shape (%r) from source (%r)" % (tsum.shape, psum.shape))

    batchsize = ground.size(0)
    # flatten channels and spatial dims together -> [B, N * S]
    tsum = tsum.float().reshape(batchsize, -1)
    psum = psum.reshape(batchsize, -1)

    intersection = psum * tsum
    sums = psum + tsum

    # smooth is added once to each side, so an empty prediction on an empty ground truth
    # scores 1 (loss 0). Putting it inside the doubled numerator would score 2 (loss -1).
    score = (2.0 * intersection.sum(1) + smooth) / (sums.sum(1) + smooth)
    return 1 - score.sum() / batchsize
def __call__(self, img, argmax=None, to_onehot=None, n_classes=None, threshold_values=None, logit_thresh=None):
    """Discretize ``img`` with optional argmax, one-hot and threshold steps, in that order.

    Each step runs when its call argument or the matching attribute on ``self`` is truthy.
    ``n_classes`` / ``logit_thresh`` fall back to ``self.n_classes`` / ``self.logit_thresh``
    when None. The result is always converted with ``.float()``.
    """
    out = img
    if argmax or self.argmax:
        out = torch.argmax(out, dim=1, keepdim=True)
    if to_onehot or self.to_onehot:
        num = n_classes if n_classes is not None else self.n_classes
        out = one_hot(out, num)
    if threshold_values or self.threshold_values:
        cutoff = logit_thresh if logit_thresh is not None else self.logit_thresh
        out = out >= cutoff
    return out.float()
def __call__(  # type: ignore # see issue #495
    self,
    img,
    argmax: Optional[bool] = None,
    to_onehot: Optional[bool] = None,
    n_classes: Optional[int] = None,
    threshold_values: Optional[bool] = None,
    logit_thresh: Optional[float] = None,
):
    """Apply optional argmax -> one-hot -> threshold to ``img`` and return a float tensor.

    A step is enabled when its argument or the corresponding ``self`` attribute is truthy.
    Unset ``n_classes`` / ``logit_thresh`` fall back to ``self.n_classes`` /
    ``self.logit_thresh``. The one-hot step asserts that the class count is an ``int``.
    """
    result = img
    if argmax or self.argmax:
        result = torch.argmax(result, dim=1, keepdim=True)
    if to_onehot or self.to_onehot:
        class_total = n_classes if n_classes is not None else self.n_classes
        assert isinstance(class_total, int), "One of self.n_classes or n_classes must be an integer"
        result = one_hot(result, class_total)
    if threshold_values or self.threshold_values:
        threshold = logit_thresh if logit_thresh is not None else self.logit_thresh
        result = result >= threshold
    return result.float()
def compute_roc_auc(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    to_onehot_y: bool = False,
    softmax: bool = False,
    average: Optional[str] = "macro",
):
    """Compute the Area Under the ROC Curve, following
    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/
    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.

    Args:
        y_pred (torch.Tensor): classification model output with the batch as first dim,
            shape [B] or [B, C]; a [B, 1] input is treated as [B].
        y (torch.Tensor): ground truth with the batch as first dim; a [B, 1] input is
            treated as [B] (and one-hot encoded to [B, C] when ``to_onehot_y`` is set).
        to_onehot_y: convert ``y`` to one-hot using C from ``y_pred``. Defaults to False.
        softmax: apply softmax over dim 1 of ``y_pred`` first. Defaults to False.
        average (`macro|weighted|micro|None`): averaging for the multi-class case.
            Default is 'macro'.

            - 'macro': unweighted mean of the per-class scores (ignores label imbalance).
            - 'weighted': mean of the per-class scores weighted by each class's support
              (the number of true instances of that class).
            - 'micro': a single score over all elements of the label indicator matrix.
            - None: the list of per-class scores.

    Raises:
        ValueError: for inputs that are not 1-D or 2-D, or an unsupported ``average``.
        AssertionError: if ``y`` and ``y_pred`` shapes differ in the multi-class case.

    Note:
        ROCAUC expects y to be comprised of 0's and 1's. `y_pred` must be either prob.
        estimates or confidence values. Both single-channel options are ignored with a warning.
    """
    if y_pred.ndimension() not in (1, 2):
        raise ValueError("predictions should be of shape (batch_size, n_classes) or (batch_size, ).")
    if y.ndimension() not in (1, 2):
        raise ValueError("targets should be of shape (batch_size, n_classes) or (batch_size, ).")

    # a trailing singleton class axis means the binary case
    if y_pred.ndimension() == 2 and y_pred.shape[1] == 1:
        y_pred = y_pred.squeeze(dim=-1)
    if y.ndimension() == 2 and y.shape[1] == 1:
        y = y.squeeze(dim=-1)

    if y_pred.ndimension() == 1:
        if to_onehot_y:
            warnings.warn("y_pred has only one channel, to_onehot_y=True ignored.")
        if softmax:
            warnings.warn("y_pred has only one channel, softmax=True ignored.")
        return _calculate(y, y_pred)

    n_classes = y_pred.shape[1]
    if to_onehot_y:
        y = one_hot(y, n_classes)
    if softmax:
        y_pred = y_pred.float().softmax(dim=1)

    assert y.shape == y_pred.shape, "data shapes of y_pred and y do not match."

    if average == "micro":
        return _calculate(y.flatten(), y_pred.flatten())

    # transpose to [C, B] so that iterating yields one class at a time
    per_class_y, per_class_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
    scores = [_calculate(cls_y, cls_pred) for cls_y, cls_pred in zip(per_class_y, per_class_pred)]

    if average is None:
        return scores
    if average == "macro":
        return np.mean(scores)
    if average == "weighted":
        support = [sum(cls_y) for cls_y in per_class_y]
        return np.average(scores, weights=support)
    raise ValueError("unsupported average method.")
def compute_meandice(
    y_pred,
    y,
    include_background=True,
    to_onehot_y=True,
    mutually_exclusive=False,
    add_sigmoid=False,
    logit_thresh=0.5,
):
    """Computes Dice score metric from full size Tensor and collects average.

    Args:
        y_pred (torch.Tensor): input data to compute, typical segmentation model output.
            it must be One-Hot format and first dim is batch, example shape: [16, 3, 32, 32].
        y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch.
            example shape: [16, 1, 32, 32] will be converted into [16, 3, 32, 32].
            alternative shape: [16, 3, 32, 32] and set `to_onehot_y=False` to use 3-class labels directly.
        include_background (Bool): whether to include the first channel of the predicted output
            in the Dice computation. When False, the first channel is dropped from both `y` and
            `y_pred` (only if the tensor has more than one channel). Defaults to True.
        to_onehot_y (Bool): whether to convert `y` into the one-hot format. Defaults to True.
        mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using
            a combination of argmax and to_onehot. Defaults to False.
        add_sigmoid (Bool): whether to add sigmoid function to y_pred before computation. Defaults to False.
        logit_thresh (Float): the threshold value used to convert (after sigmoid if `add_sigmoid=True`)
            `y_pred` into a binary matrix. Defaults to 0.5.

    Returns:
        Dice scores per batch and per class (shape: [batch_size, n_classes]). Entries whose
        ground-truth channel has no positive sum are nan.

    Raises:
        ValueError: when `add_sigmoid=True` is combined with `mutually_exclusive=True` for a
            multi-channel `y_pred`.
        AssertionError: when `y` and `y_pred` shapes differ after preprocessing.

    Note:
        This method provides two options to convert `y_pred` into a binary matrix:
        (1) when `mutually_exclusive` is True, it uses a combination of ``argmax`` and ``to_onehot``;
        (2) when `mutually_exclusive` is False, it uses a threshold ``logit_thresh``
        (optionally with a ``sigmoid`` function before thresholding).
        For a single-channel `y_pred`, thresholding is always used and `y` is binarized with `y > 0`.
    """
    n_classes = y_pred.shape[1]
    n_len = len(y_pred.shape)

    if add_sigmoid:
        y_pred = y_pred.float().sigmoid()

    if n_classes == 1:
        if mutually_exclusive:
            warnings.warn('y_pred has only one class, mutually_exclusive=True ignored.')
        if to_onehot_y:
            warnings.warn('y_pred has only one channel, to_onehot_y=True ignored.')
        if not include_background:
            warnings.warn('y_pred has only one channel, include_background=False ignored.')
        # make both y and y_pred binary
        y_pred = (y_pred >= logit_thresh).float()
        y = (y > 0).float()
    else:  # multi-channel y_pred
        # make both y and y_pred binary
        if mutually_exclusive:
            if add_sigmoid:
                raise ValueError('add_sigmoid=True is incompatible with mutually_exclusive=True.')
            y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
            y_pred = one_hot(y_pred, n_classes)
        else:
            y_pred = (y_pred >= logit_thresh).float()
        if to_onehot_y:
            y = one_hot(y, n_classes)

    if not include_background:
        y = y[:, 1:] if y.shape[1] > 1 else y
        y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred

    assert y.shape == y_pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" %
                                     (y.shape, y_pred.shape))

    # reducing only spatial dimensions (not batch nor channels)
    reduce_axis = list(range(2, n_len))
    intersection = torch.sum(y * y_pred, reduce_axis)

    y_o = torch.sum(y, reduce_axis)
    y_pred_o = torch.sum(y_pred, reduce_axis)
    denominator = y_o + y_pred_o

    # channels with an empty ground truth are reported as nan rather than 0 or 1
    f = torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float('nan')).to(y_o.float()))
    return f  # returns array of Dice shape: [Batch, n_classes]
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Dice loss (or Jaccard loss with ``self.jaccard``) raised to ``self.pow``.

    Args:
        input: the shape should be BNH[WD].
        target: the shape should be BNH[WD]

    Returns:
        A scalar for "mean" / "sum" reduction; for "none", [N, n_classes] losses, or
        [n_classes] when ``self.batch_version`` also reduces over the batch axis.

    Raises:
        ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
    """
    if self.sigmoid:
        input = torch.sigmoid(input)

    n_pred_ch = input.shape[1]
    multi_channel = n_pred_ch != 1

    if self.softmax:
        if multi_channel:
            input = torch.softmax(input, 1)
        else:
            warnings.warn("single channel prediction, `softmax=True` ignored.")

    if self.other_act is not None:
        input = self.other_act(input)

    if self.to_onehot_y:
        if multi_channel:
            target = one_hot(target, num_classes=n_pred_ch)
        else:
            warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")

    if not self.include_background:
        if multi_channel:
            # drop the first (background) channel from both tensors
            target = target[:, 1:]
            input = input[:, 1:]
        else:
            warnings.warn("single channel prediction, `include_background=False` ignored.")

    assert (
        target.shape == input.shape
    ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})"

    # spatial axes are always reduced; batch mode also folds in axis 0
    reduce_axis = list(range(2, len(input.shape)))
    if self.batch_version:
        reduce_axis = [0] + reduce_axis

    intersection = torch.sum(target * input, dim=reduce_axis)

    if self.squared_pred:
        target = torch.pow(target, 2)
        input = torch.pow(input, 2)

    denominator = torch.sum(target, dim=reduce_axis) + torch.sum(input, dim=reduce_axis)
    if self.jaccard:
        denominator = 2.0 * (denominator - intersection)

    overlap = (2.0 * intersection + self.smooth_num) / (denominator + self.smooth_den)
    f: torch.Tensor = (1.0 - overlap) ** self.pow

    if self.reduction == LossReduction.MEAN.value:
        return torch.mean(f)  # average over batch and channels
    if self.reduction == LossReduction.SUM.value:
        return torch.sum(f)  # sum over batch and channels
    if self.reduction == LossReduction.NONE.value:
        return f  # unreduced losses
    raise ValueError(
        f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].'
    )
def _test_epoch(self, epoch):
    """Run one evaluation pass over ``self.test_loader`` and return the aggregated logs.

    For every batch, the model returns two outputs: ``out1`` is scored by ``self.loss`` and the
    segmentation metrics, and ``out2`` by ``self.criterion`` and the classification metrics.
    The total loss is ``(1 - self.loss_weight) * clf_loss + self.loss_weight * sgm_loss``.
    Each entry of ``self.metrics`` is indexed as (metric callable, name, type), with type
    'classification' or 'segmentation'. Step and epoch values are sent to
    ``neptune.log_metric``.

    Args:
        epoch: not referenced in the body.

    Returns:
        dict mapping 'loss' and each metric name to its tracker's ``mean`` attribute
        (presumably the running average over batches -- confirm in AverageMetricTracker).

    Raises:
        ValueError: if a metric's type is neither 'classification' nor 'segmentation'.
    """
    start = time.time()
    logs = {}
    loss_meter = AverageMetricTracker()
    # one tracker per metric name (m[1])
    metric_trackers = {m[1]: AverageMetricTracker() for m in self.metrics}
    self.model.eval()
    for batch_idx, (x, y) in enumerate(self.test_loader):
        x = x.to(self.device)
        y = y.to(self.device)
        with torch.no_grad():
            # make predictions
            out1, out2 = self.model.forward(x)
            # NOTE(review): presumably derives per-class targets for the classification head from y -- confirm
            classes = class_count(y, self.num_classes)
            # calculate weighted loss
            clf_loss = self.criterion(out2, classes)
            sgm_loss = self.loss(out1, y)
            loss = ((1.0 - self.loss_weight) * clf_loss) + (self.loss_weight * sgm_loss)
            loss_value = loss.cpu().detach().numpy()
            loss_meter.add(loss_value)
            loss_logs = {'loss': loss_meter.mean}
            logs.update(loss_logs)
            # neptune logging (test step)
            neptune.log_metric('test_loss_step', loss_value)
            for m in self.metrics:
                metric = m[0]  # unpack metric class
                metric_name = m[1]  # unpack metric name
                metric_type = m[2]  # unpack metric type
                if metric_type == 'classification':
                    metric_value = metric(out2, classes).cpu().detach().numpy()
                    metric_trackers[metric_name].add(metric_value)
                elif metric_type == 'segmentation':
                    # device index of out1 when on CUDA, otherwise 'cpu'
                    d = out1.get_device() if out1.is_cuda else 'cpu'
                    if not self.binary:
                        # multi-class: argmax over channels, then one-hot back to out1's channel count
                        y_pred = torch.argmax(out1, dim=1, keepdim=True).to(d)
                        y_pred = one_hot(y_pred, num_classes=out1.shape[1])
                    else:
                        # binary: round the output to 0/1
                        y_pred = torch.round(out1).to(d)
                    # calculate metric
                    metric_value = metric(y_pred, y)
                    # some metrics return a tuple; keep only element [0][0] in that case
                    metric_value = metric_value[0][0] if isinstance(metric_value, tuple) else metric_value
                    metric_value = metric_value.cpu().detach().numpy()
                    metric_trackers[metric_name].add(metric_value)
                else:
                    raise ValueError(f'Type {metric_type} is not a valid metric type.')
                # neptune logging (test step)
                neptune.log_metric('test_' + metric_name + '_step', metric_value)
            metrics_logs = {k: v.mean for k, v in metric_trackers.items()}
            logs.update(metrics_logs)
    # neptune logging (test epoch)
    for k, v in logs.items():
        neptune.log_metric('test_' + k + '_epoch', v)
    duration = time.time() - start
    if self.verbose:
        self._show_progress(duration, logs, stage='Test')
    return logs
def test_shape(self, input_data, expected_shape, expected_result=None):
    """Assert ``one_hot(**input_data)`` has ``expected_shape`` and, if given, ``expected_result`` values."""
    encoded = one_hot(**input_data)
    self.assertEqual(encoded.shape, expected_shape)
    if expected_result is None:
        return
    self.assertTrue(np.allclose(expected_result, encoded.numpy()))
def forward(self, pred, gt):
    """Boundary loss built from a per-class boundary F1 (BF1) score.

    Input:
        - pred: the output from model (before softmax), shape (N, C, H, W)
        - gt: ground truth map, shape (N, 1, H, W)

    Return:
        - boundary loss ``mean((1 - BF1) ** self.gamma)``, averaged over the mini-batch
          and the classes

    Note:
        Pooling uses stride 1 and padding ``(k - 1) // 2``, which keeps H and W unchanged only
        for odd ``self.theta0`` / ``self.theta``.
    """
    n, c, _, _ = pred.shape

    # softmax so that predicted map can be distributed in [0, 1]
    pred = torch.softmax(pred, dim=1)

    # one-hot vector of ground truth
    one_hot_gt = one_hot(gt, c)

    # boundary map: max-pooling the inverted map and subtracting the inverted map again leaves
    # the pixels of each class region that have a non-class pixel within the theta0 window
    gt_b = F.max_pool2d(1 - one_hot_gt, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
    gt_b -= 1 - one_hot_gt

    pred_b = F.max_pool2d(1 - pred, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
    pred_b -= 1 - pred

    # extended boundary map: boundaries dilated with a theta x theta max-pool
    gt_b_ext = F.max_pool2d(gt_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)
    pred_b_ext = F.max_pool2d(pred_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)

    # flatten spatial dims -> (N, C, H*W)
    gt_b = gt_b.view(n, c, -1)
    pred_b = pred_b.view(n, c, -1)
    gt_b_ext = gt_b_ext.view(n, c, -1)
    pred_b_ext = pred_b_ext.view(n, c, -1)

    smooth = 1e-7

    # Precision, Recall (per sample and class)
    P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + smooth)
    R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + smooth)

    # Boundary F1 Score
    BF1 = (2 * P * R) / (P + R + smooth)

    # (1 - BF1) raised to self.gamma, averaged over mini-batch and classes
    loss = torch.mean(torch.pow(1 - BF1, self.gamma))
    return loss