def forward(self, pred, label, sample_weight=None):
    """Sigmoid focal loss, reduced to one value per batch element.

    Args:
        pred: Model output. ``torch.sigmoid`` is applied first unless
            ``self._from_logits`` is truthy.
        label: Targets. Entries equal to ``self._ignore_label`` are masked
            out; entries greater than 0.5 are treated as positives.
        sample_weight: Unused. The weighting mask is always rebuilt from
            ``label``; the argument is kept only for call compatibility.

    Returns:
        ``self._scale`` times the loss summed over every axis except
        ``self._batch_axis``. When ``self._size_average`` is set, each sum is
        divided by the number of non-ignored entries (plus ``self._eps``).
    """
    positive = label > 0.5
    valid = label != self._ignore_label
    probs = pred if self._from_logits else torch.sigmoid(pred)

    # Class-balancing factor; multiplying by the mask zeroes ignored entries.
    alpha = torch.where(positive, self._alpha * valid, (1 - self._alpha) * valid)
    # 1 - |label - prob| (p for label 1, 1 - p for label 0); ignored entries
    # get pt = 1, so (1 - pt) is zero there.
    pt = torch.where(valid, 1.0 - torch.abs(label - probs), torch.ones_like(probs))
    focal = (1 - pt) ** self._gamma

    # Capping pt + eps at 1 keeps the log non-positive.
    one = torch.ones(1, dtype=torch.float, device=pt.device)
    loss = -alpha * focal * torch.log(torch.min(pt + self._eps, one))
    loss = self._weight * (loss * valid)

    per_sample = torch.sum(
        loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
    if self._size_average:
        n_valid = torch.sum(
            valid, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
        per_sample = per_sample / (n_valid + self._eps)
    return self._scale * per_sample
def forward(self, pred, label, sample_weight=None):
    """Sigmoid focal loss normalised by the number of positive entries.

    Args:
        pred: Model output; ``F.sigmoid`` is applied unless
            ``self._from_logits`` is truthy.
        label: Targets. Entries equal to -1 are ignored (the value is
            hard-coded here); entries greater than 0 count as positives.
        sample_weight: Unused; the mask is rebuilt from ``label``.

    Returns:
        ``self._scale`` times the loss summed over every axis except
        ``self._batch_axis``. With ``self._size_average`` set, each sum is
        divided by the count of entries where ``label == 1`` (plus
        ``self._eps``).

    NOTE(review): positives are selected with ``label > 0`` but counted with
    ``label == 1``. A sample with no ``label == 1`` entries is divided by
    ``self._eps`` alone. Confirm both are intended.
    """
    probs = pred if self._from_logits else F.sigmoid(pred)
    positive = label > 0
    # One ignore mask, reused for alpha and for the final weighting.
    valid = label != -1

    # pt = p for positives, 1 - p for everything else.
    pt = torch.where(positive, probs, 1 - probs)
    alpha = torch.where(positive, self._alpha * valid, (1 - self._alpha) * valid)
    focal = (1 - pt) ** self._gamma

    # Capping pt + eps at 1 keeps the log non-positive.
    one = torch.ones(1, dtype=torch.float, device=pt.device)
    loss = -alpha * focal * torch.log(torch.min(pt + self._eps, one))
    loss = self._weight * (loss * valid)

    per_sample = torch.sum(
        loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))
    if self._size_average:
        num_pos = torch.sum(
            label == 1, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis))
        per_sample = per_sample / (num_pos + self._eps)
    return self._scale * per_sample
def forward(self, pred, label):
    """Normalized sigmoid focal loss that also tracks running statistics.

    The focal modulation term ``(1 - pt) ** gamma`` is rescaled over the last
    two axes so that its sum matches the number of non-ignored entries along
    those axes. The match is exact only up to ``self._eps`` and the optional
    ``self._max_mult`` clamp. The last two axes are presumably spatial (H, W);
    confirm against callers.

    Args:
        pred: Model output; ``torch.sigmoid`` is applied unless
            ``self._from_logits`` is truthy.
        label: Targets. Entries equal to ``self._ignore_label`` are masked
            out; entries greater than 0.5 count as positives.

    Returns:
        The loss summed over every axis except ``self._batch_axis``. With
        ``self._size_average`` set, each sum is divided by the number of
        non-ignored entries (plus ``self._eps``). No ``self._scale`` factor
        is applied in this method.

    Side effects:
        Updates the moving averages ``self._k_sum`` (and, per the structure
        chosen below, ``self._m_max``) when at least one sample in the batch
        has no ignored entries.
    """
    one_hot = label > 0.5
    # Mask of entries that take part in the loss.
    sample_weight = label != self._ignore_label

    if not self._from_logits:
        pred = torch.sigmoid(pred)

    # Class-balancing weight, zeroed on ignored entries.
    alpha = torch.where(one_hot, self._alpha * sample_weight,
                        (1 - self._alpha) * sample_weight)
    # 1 - |label - pred|; ignored entries get pt = 1, so (1 - pt) is zero there.
    pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred),
                     torch.ones_like(pred))

    beta = (1 - pt)**self._gamma

    # Rescale beta over the last two axes so its sum tracks the valid-entry count.
    sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
    beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
    mult = sw_sum / (beta_sum + self._eps)
    if self._detach_delimeter:
        # Treat the normaliser as a constant for backprop.
        mult = mult.detach()
    beta = beta * mult
    if self._max_mult > 0:
        beta = torch.clamp_max(beta, self._max_mult)

    # Bookkeeping only. no_grad means .numpy() is never called on a tensor
    # that requires grad, and the statistics are copied to the CPU each call.
    with torch.no_grad():
        # Number of ignored entries per sample.
        ignore_area = torch.sum(label == self._ignore_label,
                                dim=tuple(range(1, label.dim()))).cpu().numpy()
        # Mean normaliser per sample.
        sample_mult = torch.mean(mult,
                                 dim=tuple(range(1, mult.dim()))).cpu().numpy()
        if np.any(ignore_area == 0):
            # EMA (0.9 / 0.1) of the normaliser over samples without ignored entries.
            self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[
                ignore_area == 0].mean()

            # NOTE(review): the original source was collapsed onto one line, so
            # whether these three statements sit inside the `if` is not visible.
            # They are placed inside here; confirm against the upstream file.
            # EMA (0.8 / 0.2) of the batch-mean of each sample's max beta.
            beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1)
            beta_pmax = beta_pmax.mean().item()
            self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax

    # Capping pt + eps at 1 keeps the log non-positive.
    loss = -alpha * beta * torch.log(
        torch.min(pt + self._eps,
                  torch.ones(1, dtype=torch.float).to(pt.device)))
    loss = self._weight * (loss * sample_weight)

    if self._size_average:
        bsum = torch.sum(sample_weight,
                         dim=misc.get_dims_with_exclusion(
                             sample_weight.dim(), self._batch_axis))
        loss = torch.sum(loss,
                         dim=misc.get_dims_with_exclusion(
                             loss.dim(), self._batch_axis)) / (bsum + self._eps)
    else:
        loss = torch.sum(loss,
                         dim=misc.get_dims_with_exclusion(
                             loss.dim(), self._batch_axis))

    return loss
def forward(self, pred, label, sample_weight=None):
    """Normalized sigmoid focal loss that tracks a running normaliser.

    The focal term ``(1 - pt) ** gamma`` is rescaled over the last two axes so
    that its sum matches the number of non-ignored entries there. The match
    is exact only up to ``self._eps``.

    Args:
        pred: Model output; ``torch.sigmoid`` is applied unless
            ``self._from_logits`` is truthy.
        label: Targets. Entries equal to ``self._ignore_label`` are masked
            out; entries greater than 0 count as positives.
        sample_weight: Unused; the mask is rebuilt from ``label``.

    Returns:
        ``self._scale`` times the loss summed over every axis except
        ``self._batch_axis``. With ``self._size_average`` set, each sum is
        divided by the number of non-ignored entries (plus ``self._eps``).

    Side effects:
        Updates ``self._k_sum`` (EMA of the mean normaliser) when at least
        one sample in the batch has no ignored entries.
    """
    one_hot = label > 0
    sample_weight = label != self._ignore_label

    if not self._from_logits:
        pred = torch.sigmoid(pred)

    # Class-balancing weight, zeroed on ignored entries.
    alpha = torch.where(one_hot, self._alpha * sample_weight,
                        (1 - self._alpha) * sample_weight)
    # pt = p for positives, 1 - p otherwise; ignored entries are pinned to 1.
    pt = torch.where(one_hot, pred, 1 - pred)
    pt = torch.where(sample_weight, pt, torch.ones_like(pt))

    beta = (1 - pt) ** self._gamma

    # Rescale beta over the last two axes so its sum tracks the valid-entry count.
    sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
    beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
    mult = sw_sum / (beta_sum + self._eps)
    if self._detach_delimeter:
        mult = mult.detach()
    beta = beta * mult

    # Bookkeeping only. Without no_grad, `.numpy()` raises a RuntimeError
    # whenever `mult` still requires grad (i.e. _detach_delimeter is False).
    with torch.no_grad():
        ignore_area = torch.sum(label == self._ignore_label,
                                dim=tuple(range(1, label.dim()))).cpu().numpy()
        sample_mult = torch.mean(mult,
                                 dim=tuple(range(1, mult.dim()))).cpu().numpy()
        if np.any(ignore_area == 0):
            self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean()

    # Capping pt + eps at 1 keeps the log non-positive.
    loss = -alpha * beta * torch.log(
        torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device)))
    loss = self._weight * (loss * sample_weight)

    if self._size_average:
        bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(
            sample_weight.dim(), self._batch_axis))
        loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(
            loss.dim(), self._batch_axis)) / (bsum + self._eps)
    else:
        loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(
            loss.dim(), self._batch_axis))

    return self._scale * loss
def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
    """Per-sample intersection-over-union of two masks.

    Axis 0 is excluded from the reduction, so it is treated as the per-sample
    axis. The masks presumably are boolean tensors, since they are combined
    with ``|`` and ``&``; confirm against callers.

    Args:
        pred_mask: Predicted mask.
        gt_mask: Ground-truth mask, same shape as ``pred_mask``.
        ignore_mask: Optional mask; where it is True, ``pred_mask`` is forced
            to zero. ``gt_mask`` is left untouched.
        keep_ignore: If False, return IoU only for samples whose union is
            non-empty. If True, return one value per sample, with -1 for
            samples whose union is empty.

    Returns:
        A NumPy array of IoU values.
    """
    if ignore_mask is not None:
        pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)

    dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)

    def _per_sample_mean(mask):
        # Means rather than counts; the IoU ratio is the same either way.
        return torch.mean(mask.float(), dim=dims).detach().cpu().numpy()

    union = _per_sample_mean(pred_mask | gt_mask)
    intersection = _per_sample_mean(pred_mask & gt_mask)
    has_union = union > 0
    iou = intersection[has_union] / union[has_union]

    if keep_ignore:
        padded = np.full_like(intersection, -1)
        padded[has_union] = iou
        return padded
    return iou
def forward(self, pred, label):
    """Sigmoid binary cross-entropy with an ignore label, one value per sample.

    Args:
        pred: Logits when ``self._from_sigmoid`` is falsy, otherwise
            probabilities.
        label: Targets, reshaped to ``pred``'s shape. Entries equal to
            ``self._ignore_label`` are replaced by 0 and excluded through a
            zero weight.

    Returns:
        The weighted loss averaged over every axis except
        ``self._batch_axis``. Ignored entries contribute zero but still count
        in the mean's denominator.
    """
    # reshape instead of view: view raises for label tensors whose strides are
    # incompatible with pred's shape (e.g. after a permute); reshape still
    # returns a view whenever one is possible.
    label = label.reshape(pred.size())
    sample_weight = label != self._ignore_label
    label = torch.where(sample_weight, label, torch.zeros_like(label))

    if not self._from_sigmoid:
        # Numerically stable BCE-with-logits:
        # max(x, 0) - x * y + log(1 + exp(-|x|)).
        loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred))
    else:
        eps = 1e-12
        loss = -(torch.log(pred + eps) * label
                 + torch.log(1. - pred + eps) * (1. - label))

    loss = self._weight * (loss * sample_weight)
    return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis))