Ejemplo n.º 1
0
def _focal_loss(pred, gt, alpha=2, beta=4):
    '''Modified focal loss. Exactly the same as CornerNet.

    Modified for more stability by using the log_sigmoid function: the log
    probabilities are computed as ``log_sigmoid(pred)`` and
    ``log_sigmoid(-pred)`` rather than ``log(sigmoid(pred))`` and
    ``log(1 - sigmoid(pred))``.

      Arguments:
        pred (batch x c x h x w): logit (must be values before sigmoid activation)
        gt (batch x c x h x w): target heatmap, same shape as ``pred``.
            Elements with value >= 1 are treated as positives; all other
            elements are negatives whose loss is scaled by
            ``(1 - gt) ** beta``.
        alpha (float): exponent on the probability term,
            ``(1 - p) ** alpha`` for positives and ``p ** alpha`` for
            negatives. Defaults to 2.
        beta (float): exponent of the ``(1 - gt)`` weight applied to
            negatives. Defaults to 4.

      Returns:
        The fully reduced loss: ``-(pos_loss + neg_loss) / max(num_pos, 1)``,
        where ``num_pos`` is the number of positive elements.
    '''
    # Mask of positive locations (gt >= 1) and its complement.
    pos_inds = F.greater_equal_scalar(gt, 1)
    neg_inds = 1 - pos_inds
    # Negatives whose gt value is close to 1 are penalized less.
    neg_weights = F.pow_scalar(1.0 - gt, beta)
    prob_pred = F.sigmoid(pred)
    # Positives: log(p) * (1 - p)^alpha, with log(p) taken as log_sigmoid(pred).
    pos_loss = F.log_sigmoid(pred) * F.pow_scalar(1.0 - prob_pred,
                                                  alpha) * pos_inds
    pos_loss = F.sum(pos_loss)
    # Negatives: log(1 - p) * p^alpha * (1 - gt)^beta;
    # log_sigmoid(-x) == log(1 - sigmoid(x)).
    neg_loss = F.log_sigmoid(-pred) * F.pow_scalar(
        prob_pred, alpha) * neg_weights * neg_inds
    neg_loss = F.sum(neg_loss)
    # Normalize by the positive count, clamped to 1 so a target with no
    # positives does not divide by zero.
    num_pos = F.maximum_scalar(F.sum(pos_inds), 1)
    loss = -(1 / num_pos) * (pos_loss + neg_loss)
    return loss
Ejemplo n.º 2
0
 def _loss_minus(self, dout):
     """Return the element-wise loss ``-log(1 - sigmoid(dout))``.

     Because ``log_sigmoid(-x) == log(1 - sigmoid(x))``, this is the
     cross-entropy term for a target of 0 computed directly from the
     logits ``dout``. Presumably used for the "minus" (negative) class —
     confirm with callers. ``self`` is not read.
     """
     flipped = -dout
     return -F.log_sigmoid(flipped)
Ejemplo n.º 3
0
 def _loss_plus(self, dout):
     """Return the element-wise loss ``-log(sigmoid(dout))``.

     This is the cross-entropy term for a target of 1 computed directly
     from the logits ``dout``. Presumably used for the "plus" (positive)
     class — confirm with callers. ``self`` is not read.
     """
     log_prob = F.log_sigmoid(dout)
     return -log_prob