예제 #1
0
def get_allLoss(op, # Network output
                elOut, # Network ellipse regression output
                target, # Segmentation targets
                pupil_center, # Pupil center
                elNorm, # Normalized ellipse parameters
                spatWts,
                distMap,
                cond, # Condition matrix, 0 represents modality exists
                ID,
                alpha):
    """Combine segmentation, center-of-mass and ellipse regression losses.

    Args:
        op: Segmentation output, unpacked as a 4-D ``(B, C, H, W)`` tensor.
            Channel 2 drives the pupil center-of-mass term; the negated
            channel 0 (background) drives the iris center-of-mass term.
        elOut: Ellipse regression output. Columns 5:7 are compared against
            the normalized pupil center; the whole tensor is compared
            against ``elNorm.view(-1, 10)``.
        target: Segmentation targets, forwarded to ``get_segLoss``; its
            ``shape[1:]`` is used to normalize ``pupil_center``.
        pupil_center: Pupil center, normalized with ``normPts`` before use.
        elNorm: Normalized ellipse parameters. ``elNorm[:, 0, :2]`` is used
            as the iris center. Values are only meaningful for samples whose
            ground-truth mask is present.
        spatWts, distMap, alpha: Forwarded unchanged to ``get_segLoss``.
        cond: Condition matrix. ``1 - cond[:, 1]`` is used as a per-sample
            weight (1 means the ground-truth mask exists).
        ID: Not used by this function.

    Returns:
        Tuple ``(total_loss, pred_c_seg)``. ``total_loss`` is
        ``l_seg2pt + 20 * l_seg + 10 * (l_pt + l_ellipse)``. ``pred_c_seg``
        stacks the predicted iris and pupil centers along dim 1, iris first.
    """
    # Unpacking only enforces a 4-D output; the sizes themselves are unused.
    _b, _c, _h, _w = op.shape

    # Per-sample weight: 1.0 where the GT mask exists, 0.0 otherwise.
    has_mask = (1 - cond[:, 1]).to(torch.float32)
    has_mask.requires_grad = False  # Never backprop through the condition mask

    pupil_norm = normPts(pupil_center, target.shape[1:])

    # Pupil center-of-mass loss from the pupil channel, averaged over the batch.
    l_pup_raw, pred_c_seg_pup = get_seg2ptLoss(op[:, 2, ...], pupil_norm,
                                               temperature=4)
    l_seg2pt_pup = torch.mean(l_pup_raw)

    if not torch.sum(has_mask):
        # No sample carries a GT mask, so elNorm holds no usable iris center.
        # Drop the iris term and reuse the regressed pupil center as iris center.
        l_seg2pt_iri = 0.0
        pred_c_seg_iri = torch.clone(elOut[:, 5:7])
    else:
        # Inverse of the background channel, used as the iris map. Samples
        # without a GT mask are zero-weighted so garbage elNorm values
        # never contribute.
        l_iri_raw, pred_c_seg_iri = get_seg2ptLoss(-op[:, 0, ...],
                                                   elNorm[:, 0, :2],
                                                   temperature=4)
        pair_weights = torch.stack([has_mask, has_mask], dim=1)
        l_seg2pt_iri = torch.sum(l_iri_raw * pair_weights) / torch.sum(pair_weights)

    # Iris-first ordering of the predicted centers.
    pred_c_seg = torch.stack([pred_c_seg_iri, pred_c_seg_pup], dim=1)
    l_seg2pt = 0.5 * l_seg2pt_pup + 0.5 * l_seg2pt_iri

    # Backbone segmentation loss.
    l_seg = get_segLoss(op, target, spatWts, distMap, has_mask, alpha)

    # Bottleneck pupil-center regression, weighted by 1 - has_mask: intended
    # for samples where normalized ellipses do not exist.
    l_pt = get_ptLoss(elOut[:, 5:7], pupil_norm, 1 - has_mask)

    # Full ellipse regression, weighted by has_mask (samples with valid ellipses).
    l_ellipse = get_ptLoss(elOut, elNorm.view(-1, 10), has_mask)

    total_loss = l_seg2pt + 20 * l_seg + 10 * (l_pt + l_ellipse)
    return total_loss, pred_c_seg
예제 #2
0
def get_allLoss(
        op,  # Network output
        target,  # Segmentation targets
        pupil_center,  # Pupil center
        cond,  # Condition
):
    """Pupil-only loss: binary pupil segmentation plus pupil center-of-mass.

    Args:
        op: Network output indexed as ``op[:, channel, ...]`` with the batch
            on dim 0. Channel 1 is treated as the pupil channel. It is passed
            to ``F.cross_entropy`` as unnormalized class scores (logits);
            presumably ``(B, 2, H, W)``, TODO confirm.
        target: Segmentation labels. Entries equal to 2 become class 1
            (pupil); everything else becomes class 0.
        pupil_center: Pupil center, normalized with ``normPts`` against
            ``target.shape[1:]``.
        cond: Condition matrix. ``1 - cond[:, 1]`` weights each sample's
            segmentation loss (1 means a ground-truth mask is present).

    Returns:
        Tuple ``(loss, pred_c_seg_pup)``. ``loss`` is the masked, 10x-weighted
        segmentation loss plus the mean pupil center-of-mass loss.
        ``pred_c_seg_pup`` is the predicted pupil center from ``get_seg2ptLoss``.
    """
    B = op.shape[0]
    # Cast before subtracting: works for bool/int/float cond and keeps the
    # weight in floating point for the weighted average below.
    log_seg_ok = 1 - cond[:, 1].to(torch.float32)
    target = (target == 2).to(torch.long)  # 1 for pupil, 0 otherwise

    l_seg2pt_pup, pred_c_seg_pup = get_seg2ptLoss(
        op[:, 1, ...],
        normPts(pupil_center, target.shape[1:]),
        temperature=4)

    # F.cross_entropy applies log_softmax internally and expects raw logits;
    # passing softmax probabilities would normalize twice, compressing the
    # loss range and weakening gradients.
    l_seg = 10 * F.cross_entropy(op, target, reduction='none')
    if torch.sum(log_seg_ok):
        # Mean over pixels per sample, then weighted mean over samples that
        # have a ground-truth mask.
        l_seg = torch.sum(l_seg.reshape(B, -1).mean(dim=1) *
                          log_seg_ok) / torch.sum(log_seg_ok)
    else:
        # No sample has a ground-truth mask: segmentation term contributes 0.
        l_seg = 0.0
    loss = l_seg + torch.mean(l_seg2pt_pup)
    return loss, pred_c_seg_pup