def _focal_loss(pred, gt, alpha=2, beta=4):
    '''Modified focal loss, following CornerNet, computed from logits.

    log_sigmoid is applied to the logits directly (instead of taking the
    log of the sigmoid output) for better numerical stability.

    Arguments:
      pred (batch x c x h x w): logits (must be values before sigmoid
        activation)
      gt (batch x c x h x w): target map; elements >= 1 are treated as
        positives, every other element as a negative
      alpha: exponent of the focusing term, (1 - p)^alpha for positives
        and p^alpha for negatives, where p = sigmoid(pred). Defaults to 2.
      beta: exponent of the (1 - gt)^beta weight applied to negatives.
        Defaults to 4.

    Returns:
      The negated sum of the positive and negative terms, divided by
      max(number of positives, 1).
    '''
    # 1 where gt >= 1, 0 elsewhere; the negative mask is its complement.
    pos_inds = F.greater_equal_scalar(gt, 1)
    neg_inds = 1 - pos_inds

    # Negatives are down-weighted as gt approaches 1.
    neg_weights = F.pow_scalar(1.0 - gt, beta)

    prob_pred = F.sigmoid(pred)

    # log(p) * (1 - p)^alpha, kept only at positive locations.
    pos_loss = F.log_sigmoid(pred) * F.pow_scalar(
        1.0 - prob_pred, alpha) * pos_inds
    pos_loss = F.sum(pos_loss)

    # log_sigmoid(-x) == log(1 - sigmoid(x)), so this is
    # log(1 - p) * p^alpha * (1 - gt)^beta at negative locations.
    neg_loss = F.log_sigmoid(-pred) * F.pow_scalar(
        prob_pred, alpha) * neg_weights * neg_inds
    neg_loss = F.sum(neg_loss)

    # Clamp the normaliser to 1 so a target without positives does not
    # cause a division by zero.
    num_pos = F.maximum_scalar(F.sum(pos_inds), 1)
    loss = -(1 / num_pos) * (pos_loss + neg_loss)
    return loss
def _loss_minus(self, dout):
    '''Return -log_sigmoid(-dout), computed element-wise.

    Mathematically this equals -log(1 - sigmoid(dout)), so the value
    grows as dout increases.

    Arguments:
      dout: logits (values before sigmoid activation).
        NOTE(review): presumably a discriminator output - confirm
        against the caller.
    '''
    log_prob_of_negated = F.log_sigmoid(-dout)
    return -log_prob_of_negated
def _loss_plus(self, dout):
    '''Return -log_sigmoid(dout), computed element-wise.

    Mathematically this equals -log(sigmoid(dout)), so the value
    shrinks as dout increases.

    Arguments:
      dout: logits (values before sigmoid activation).
        NOTE(review): presumably a discriminator output - confirm
        against the caller.
    '''
    log_prob = F.log_sigmoid(dout)
    return -log_prob