def forward(self, pred, label, sample_weight=None):
        """Compute the sigmoid focal loss for each sample in the batch.

        Elements whose label equals ``self._ignore_label`` are masked out of
        both the loss and the normalizer.

        Args:
            pred: Predictions; passed through ``torch.sigmoid`` unless
                ``self._from_logits`` is truthy.
            label: Target tensor; values ``> 0.5`` count as positives.
            sample_weight: Accepted but NOT used -- the weight mask is always
                rebuilt from ``label != self._ignore_label``.

        Returns:
            ``self._scale`` times the loss summed over the dims returned by
            ``misc.get_dims_with_exclusion(..., self._batch_axis)``
            (presumably every axis except the batch axis -- confirm). When
            ``self._size_average`` is set, each sum is divided by the number
            of non-ignored elements (plus ``self._eps``).
        """
        if not self._from_logits:
            pred = torch.sigmoid(pred)

        positive = label > 0.5
        valid = label != self._ignore_label

        alpha = torch.where(positive,
                            self._alpha * valid,
                            (1 - self._alpha) * valid)
        # 1 - |label - pred|: for a 0/1 label this is pred (positive) or
        # 1 - pred (negative). Ignored elements are pinned to 1.
        pt = torch.where(valid,
                         1.0 - torch.abs(label - pred),
                         torch.ones_like(pred))
        focal = (1 - pt)**self._gamma

        # Cap pt + eps at 1 so the log term is never positive.
        one = torch.ones(1, dtype=torch.float).to(pt.device)
        raw = -alpha * focal * torch.log(torch.min(pt + self._eps, one))
        weighted = self._weight * (raw * valid)

        per_sample = torch.sum(
            weighted,
            dim=misc.get_dims_with_exclusion(weighted.dim(), self._batch_axis))
        if self._size_average:
            valid_count = torch.sum(
                valid,
                dim=misc.get_dims_with_exclusion(label.dim(),
                                                 self._batch_axis))
            per_sample = per_sample / (valid_count + self._eps)

        return self._scale * per_sample
Example #2
0
    def forward(self, pred, label, sample_weight=None):
        """Compute the sigmoid focal loss for each sample in the batch.

        Args:
            pred: Predictions; passed through ``torch.sigmoid`` unless
                ``self._from_logits`` is truthy.
            label: Target tensor. Values ``> 0`` count as positives; elements
                equal to ``-1`` are excluded from the loss.
            sample_weight: Accepted but NOT used -- it is replaced by the
                ``label != -1`` mask.

        Returns:
            ``self._scale`` times the loss summed over the dims returned by
            ``misc.get_dims_with_exclusion(..., self._batch_axis)``
            (presumably every axis except the batch axis -- confirm). When
            ``self._size_average`` is set, each sum is divided by the number
            of elements whose label is exactly ``1`` (plus ``self._eps``).
        """
        # torch.sigmoid rather than the deprecated F.sigmoid, which warns on
        # every call; the computed values are identical.
        if not self._from_logits:
            pred = torch.sigmoid(pred)

        one_hot = label > 0
        # pt is pred for positives and 1 - pred otherwise.
        pt = torch.where(one_hot, pred, 1 - pred)

        # A single ignore mask (label == -1 is ignored) serves both the alpha
        # weighting and the final masking of the loss.
        sample_weight = label != -1
        alpha = torch.where(one_hot, self._alpha * sample_weight,
                            (1 - self._alpha) * sample_weight)
        beta = (1 - pt)**self._gamma

        # Cap pt + eps at 1 so the log term is never positive.
        loss = -alpha * beta * torch.log(
            torch.min(pt + self._eps,
                      torch.ones(1, dtype=torch.float).to(pt.device)))
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            # NOTE(review): normalizes by the count of label == 1, while the
            # positive mask above is label > 0; these differ for soft labels.
            # Confirm this is intended.
            tsum = torch.sum(label == 1,
                             dim=misc.get_dims_with_exclusion(
                                 label.dim(), self._batch_axis))
            loss = torch.sum(loss,
                             dim=misc.get_dims_with_exclusion(
                                 loss.dim(),
                                 self._batch_axis)) / (tsum + self._eps)
        else:
            loss = torch.sum(loss,
                             dim=misc.get_dims_with_exclusion(
                                 loss.dim(), self._batch_axis))

        return self._scale * loss
    def forward(self, pred, label):
        """Compute the normalized sigmoid focal loss for each sample.

        The focal term ``(1 - pt) ** gamma`` is rescaled so that, summed over
        the last two axes, it approximately equals the number of non-ignored
        elements there. It is then optionally capped at ``self._max_mult``.
        As a side effect, the running statistics ``self._k_sum`` and
        ``self._m_max`` are updated.

        Args:
            pred: Predictions; passed through ``torch.sigmoid`` unless
                ``self._from_logits`` is truthy.
            label: Target tensor; values ``> 0.5`` count as positives and
                elements equal to ``self._ignore_label`` are ignored.

        Returns:
            The loss summed over the dims returned by
            ``misc.get_dims_with_exclusion(..., self._batch_axis)``
            (presumably every axis except the batch axis -- confirm). When
            ``self._size_average`` is set, each sum is divided by the matching
            non-ignored count. No ``self._scale`` factor is applied here.
        """
        if not self._from_logits:
            pred = torch.sigmoid(pred)

        positive = label > 0.5
        valid = label != self._ignore_label

        alpha = torch.where(positive, self._alpha * valid,
                            (1 - self._alpha) * valid)
        # Ignored elements get pt = 1, so their focal term vanishes for
        # gamma > 0.
        pt = torch.where(valid, 1.0 - torch.abs(label - pred),
                         torch.ones_like(pred))
        focal = (1 - pt)**self._gamma

        # Normalize the focal term over the last two axes so that it sums to
        # (approximately) the number of valid elements there.
        valid_count = torch.sum(valid, dim=(-2, -1), keepdim=True)
        focal_total = torch.sum(focal, dim=(-2, -1), keepdim=True)
        mult = valid_count / (focal_total + self._eps)
        if self._detach_delimeter:
            mult = mult.detach()
        focal = focal * mult
        if self._max_mult > 0:
            focal = torch.clamp_max(focal, self._max_mult)

        # Exponential-moving-average statistics, kept out of autograd.
        # Only samples with no ignored elements feed _k_sum.
        with torch.no_grad():
            ignored_px = torch.sum(label == self._ignore_label,
                                   dim=tuple(range(1, label.dim())))
            ignored_px = ignored_px.cpu().numpy()
            mean_mult = torch.mean(mult, dim=tuple(range(1, mult.dim())))
            mean_mult = mean_mult.cpu().numpy()
            fully_labelled = ignored_px == 0
            if np.any(fully_labelled):
                self._k_sum = (0.9 * self._k_sum
                               + 0.1 * mean_mult[fully_labelled].mean())

                peak, _ = torch.flatten(focal, start_dim=1).max(dim=1)
                self._m_max = 0.8 * self._m_max + 0.2 * peak.mean().item()

        # Cap pt + eps at 1 so the log term is never positive.
        one = torch.ones(1, dtype=torch.float).to(pt.device)
        raw = -alpha * focal * torch.log(torch.min(pt + self._eps, one))
        weighted = self._weight * (raw * valid)

        per_sample = torch.sum(
            weighted,
            dim=misc.get_dims_with_exclusion(weighted.dim(), self._batch_axis))
        if self._size_average:
            valid_total = torch.sum(
                valid,
                dim=misc.get_dims_with_exclusion(valid.dim(),
                                                 self._batch_axis))
            per_sample = per_sample / (valid_total + self._eps)

        return per_sample
Example #4
0
    def forward(self, pred, label, sample_weight=None):
        """Compute the normalized sigmoid focal loss for each sample.

        The focal term ``(1 - pt) ** gamma`` is rescaled so that, summed over
        the last two axes, it approximately equals the number of non-ignored
        elements there. The running statistic ``self._k_sum`` is updated as a
        side effect.

        Args:
            pred: Predictions; passed through ``torch.sigmoid`` unless
                ``self._from_logits`` is truthy.
            label: Target tensor; values ``> 0`` count as positives and
                elements equal to ``self._ignore_label`` are ignored.
            sample_weight: Accepted but NOT used -- the weight mask is always
                ``label != self._ignore_label``.

        Returns:
            ``self._scale`` times the loss summed over the dims returned by
            ``misc.get_dims_with_exclusion(..., self._batch_axis)``
            (presumably every axis except the batch axis -- confirm). When
            ``self._size_average`` is set, each sum is divided by the matching
            non-ignored count.
        """
        one_hot = label > 0
        sample_weight = label != self._ignore_label

        if not self._from_logits:
            pred = torch.sigmoid(pred)

        alpha = torch.where(one_hot, self._alpha * sample_weight,
                            (1 - self._alpha) * sample_weight)
        # pt is pred for positives and 1 - pred otherwise; ignored elements
        # are pinned to 1 so their focal term vanishes for gamma > 0.
        pt = torch.where(one_hot, pred, 1 - pred)
        pt = torch.where(sample_weight, pt, torch.ones_like(pt))

        beta = (1 - pt)**self._gamma

        # Normalize the focal term over the last two axes so that it sums to
        # (approximately) the number of valid elements there.
        sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True)
        beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True)
        mult = sw_sum / (beta_sum + self._eps)
        if self._detach_delimeter:
            mult = mult.detach()
        beta = beta * mult

        # Statistics bookkeeping. This must run under no_grad: when
        # self._detach_delimeter is False, `mult` is part of the autograd
        # graph, and calling .numpy() on a tensor derived from it raises
        # RuntimeError. Only samples with no ignored elements feed the
        # exponential moving average _k_sum.
        with torch.no_grad():
            ignore_area = torch.sum(label == self._ignore_label,
                                    dim=tuple(range(
                                        1, label.dim()))).cpu().numpy()
            sample_mult = torch.mean(mult, dim=tuple(range(
                1, mult.dim()))).cpu().numpy()
            if np.any(ignore_area == 0):
                self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[
                    ignore_area == 0].mean()

        # Cap pt + eps at 1 so the log term is never positive.
        loss = -alpha * beta * torch.log(
            torch.min(pt + self._eps,
                      torch.ones(1, dtype=torch.float).to(pt.device)))
        loss = self._weight * (loss * sample_weight)

        if self._size_average:
            bsum = torch.sum(sample_weight,
                             dim=misc.get_dims_with_exclusion(
                                 sample_weight.dim(), self._batch_axis))
            loss = torch.sum(loss,
                             dim=misc.get_dims_with_exclusion(
                                 loss.dim(),
                                 self._batch_axis)) / (bsum + self._eps)
        else:
            loss = torch.sum(loss,
                             dim=misc.get_dims_with_exclusion(
                                 loss.dim(), self._batch_axis))

        return self._scale * loss
def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False):
    """Return intersection-over-union between two masks, per entry of axis 0.

    Args:
        pred_mask: Predicted mask tensor; combined with ``gt_mask`` via
            ``&`` / ``|``, so presumably boolean -- confirm at call sites.
        gt_mask: Ground-truth mask. It is reduced over the dims from
            ``misc.get_dims_with_exclusion(gt_mask.dim(), 0)`` (presumably
            every axis except 0).
        ignore_mask: Optional mask; where True, ``pred_mask`` is forced to
            zero. ``gt_mask`` is left untouched, so callers presumably clear
            ignored pixels there themselves.
        keep_ignore: If False, return IoUs only for entries whose union is
            non-empty. If True, return one value per entry, with -1 where
            the union is empty.

    Returns:
        A detached NumPy array of IoU values.
    """
    if ignore_mask is not None:
        pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask)

    dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0)

    def _fraction(mask):
        # Means rather than sums: the ratio of two means over the same
        # elements equals the ratio of the sums.
        return torch.mean(mask.float(), dim=dims).detach().cpu().numpy()

    inter = _fraction(pred_mask & gt_mask)
    union = _fraction(pred_mask | gt_mask)
    has_union = union > 0
    ious = inter[has_union] / union[has_union]

    if keep_ignore:
        padded = np.full_like(inter, -1)
        padded[has_union] = ious
        return padded
    return ious
    def forward(self, pred, label):
        """Compute the masked sigmoid binary cross-entropy per sample.

        Args:
            pred: Predictions. They are raw logits unless
                ``self._from_sigmoid`` is truthy, in which case they are used
                as probabilities.
            label: Targets, reshaped to ``pred``'s shape. Elements equal to
                ``self._ignore_label`` are zeroed and masked out of the loss.

        Returns:
            Mean of the weighted, masked loss over the dims returned by
            ``misc.get_dims_with_exclusion(..., self._batch_axis)``
            (presumably every axis except the batch axis -- confirm).
            Ignored elements still count in the denominator of this mean.
        """
        label = label.view(pred.shape)
        keep = label != self._ignore_label
        # Zero out ignored labels so they cannot distort the arithmetic;
        # their loss is masked away below regardless.
        label = torch.where(keep, label, torch.zeros_like(label))

        if self._from_sigmoid:
            eps = 1e-12
            log_p = torch.log(pred + eps)
            log_q = torch.log(1. - pred + eps)
            loss = -(log_p * label + log_q * (1. - label))
        else:
            # Numerically stable BCE from logits x and targets y:
            # max(x, 0) - x * y + log(1 + exp(-|x|)).
            loss = torch.relu(pred) - pred * label + F.softplus(
                -torch.abs(pred))

        weighted = self._weight * (loss * keep)
        return torch.mean(
            weighted,
            dim=misc.get_dims_with_exclusion(weighted.dim(), self._batch_axis))