def get_allLoss(op,  # Network output
                elOut,  # Network ellipse regression output
                target,  # Segmentation targets
                pupil_center,  # Pupil center
                elNorm,  # Normalized ellipse parameters
                spatWts,
                distMap,
                cond,  # Condition matrix, 0 represents modality exists
                ID,
                alpha):
    """Combine segmentation, center-of-mass and ellipse losses into one scalar.

    NOTE(review): another ``get_allLoss`` with a different signature follows
    this definition in the same file and will shadow it at import time —
    confirm which one callers are meant to use.

    Args:
        op: Network segmentation output. Must be 4-D (presumably
            (B, C, H, W) — TODO confirm). Channel 2 drives the pupil
            center-of-mass loss; the negated channel 0 (background) drives
            the iris center-of-mass loss.
        elOut: Ellipse regression output. Columns ``5:7`` are compared with
            the normalized pupil center; the whole tensor is compared with
            ``elNorm.view(-1, 10)``.
        target: Segmentation targets. ``target.shape[1:]`` is used to
            normalize ``pupil_center`` via ``normPts``.
        pupil_center: Ground-truth pupil center, normalized with ``normPts``.
        elNorm: Normalized ellipse parameters. ``elNorm[:, 0, :2]`` is used as
            the iris center. It holds garbage for samples without GT masks.
        spatWts, distMap, alpha: Forwarded unchanged to ``get_segLoss``.
        cond: Condition matrix. ``cond[:, 1] == 0`` marks samples whose GT
            mask is present.
        ID: Unused; kept for interface compatibility.

    Returns:
        ``(total_loss, pred_c_seg)``, where ``pred_c_seg`` stacks the predicted
        iris and pupil centers along dim 1 (iris first).

    Raises:
        ValueError: If ``op`` is not 4-dimensional.
    """
    # The shape was previously unpacked into unused locals; that unpacking
    # only enforced 4-D input, so the check is kept explicitly here.
    if op.dim() != 4:
        raise ValueError(f"op must be 4-D, got shape {tuple(op.shape)}")

    # 1.0 where the GT mask is present. ``detach`` keeps this out of autograd;
    # unlike assigning ``requires_grad = False``, it also works when ``cond``
    # is a non-leaf tensor, where setting the flag would raise.
    loc_onlyMask = (1 - cond[:, 1]).to(torch.float32).detach()

    # Segmentation -> pupil center loss using center of mass.
    l_seg2pt_pup, pred_c_seg_pup = get_seg2ptLoss(
        op[:, 2, ...],
        normPts(pupil_center, target.shape[1:]),
        temperature=4)
    l_seg2pt_pup = torch.mean(l_seg2pt_pup)

    # Segmentation -> iris center loss using center of mass.
    if torch.sum(loc_onlyMask):
        # The iris center is only available when GT masks are present. elNorm
        # holds garbage for the other samples, so their terms are masked out.
        iriMap = -op[:, 0, ...]  # Inverse of background mask
        l_seg2pt_iri, pred_c_seg_iri = get_seg2ptLoss(
            iriMap, elNorm[:, 0, :2], temperature=4)
        # Two copies of the per-sample mask; assumes the per-sample iris loss
        # broadcasts against shape (B, 2) — TODO confirm.
        temp = torch.stack([loc_onlyMask, loc_onlyMask], dim=1)
        l_seg2pt_iri = torch.sum(l_seg2pt_iri * temp) / torch.sum(temp)
    else:
        # No GT mask in the batch: no iris-center supervision. The regressed
        # pupil center (elOut[:, 5:7]) stands in for the iris prediction.
        l_seg2pt_iri = 0.0
        pred_c_seg_iri = torch.clone(elOut[:, 5:7])

    pred_c_seg = torch.stack([pred_c_seg_iri, pred_c_seg_pup], dim=1)  # Iris first policy
    l_seg2pt = 0.5 * l_seg2pt_pup + 0.5 * l_seg2pt_iri

    # Segmentation loss -> backbone loss.
    l_seg = get_segLoss(op, target, spatWts, distMap, loc_onlyMask, alpha)

    # Bottleneck pupil-center loss, weighted by 1 - loc_onlyMask, i.e. meant
    # for samples where normalized ellipses do not exist.
    l_pt = get_ptLoss(elOut[:, 5:7],
                      normPts(pupil_center, target.shape[1:]),
                      1 - loc_onlyMask)

    # Full ellipse regression loss, weighted by loc_onlyMask (valid samples).
    l_ellipse = get_ptLoss(elOut, elNorm.view(-1, 10), loc_onlyMask)

    total_loss = l_seg2pt + 20 * l_seg + 10 * (l_pt + l_ellipse)
    return (total_loss, pred_c_seg)
def get_allLoss(
    op,  # Network output
    target,  # Segmentation targets
    pupil_center,  # Pupil center
    cond,  # Condition
):
    """Pupil-only loss: masked pupil-vs-rest cross-entropy plus center-of-mass loss.

    Args:
        op: Network output, passed to ``F.cross_entropy`` as class logits
            along dim 1. Channel 1 drives the pupil center-of-mass loss.
        target: Segmentation labels. Label 2 is treated as pupil (class 1)
            and every other label as class 0. ``target.shape[1:]`` is used to
            normalize ``pupil_center`` via ``normPts``.
        pupil_center: Ground-truth pupil center, normalized with ``normPts``.
        cond: Condition matrix. ``cond[:, 1] == 0`` marks samples whose
            segmentation loss is counted.

    Returns:
        ``(loss, pred_c_seg_pup)``, where ``loss`` is the masked segmentation
        loss (scaled by 10) plus the mean center-of-mass loss.
    """
    B = op.shape[0]
    # Per-sample weight: 1 where the segmentation GT is usable.
    loc_seg_ok = 1 - cond[:, 1]

    target = (target == 2).to(torch.long)  # 1 for pupil, rest 0

    l_seg2pt_pup, pred_c_seg_pup = get_seg2ptLoss(
        op[:, 1, ...],
        normPts(pupil_center, target.shape[1:]),
        temperature=4)

    # F.cross_entropy applies log_softmax internally, so it must receive raw
    # logits. Passing softmax(op) here would apply softmax twice, compressing
    # the loss range and weakening gradients.
    l_seg = 10 * F.cross_entropy(op, target, reduction='none')
    if torch.sum(loc_seg_ok):
        per_sample = l_seg.reshape(B, -1).mean(dim=1)
        l_seg = torch.sum(per_sample * loc_seg_ok) / torch.sum(loc_seg_ok)
    else:
        # No sample in the batch has usable segmentation GT.
        l_seg = 0.0

    loss = l_seg + torch.mean(l_seg2pt_pup)
    return loss, pred_c_seg_pup