示例#1
0
def compute_vote_loss(input, output: VoteNetResults):
    """Compute vote loss: match predicted votes to ground-truth votes.

    Args:
        input: batch providing ``vote_label`` (per-point GT vote offsets with
            ``3 * GT_VOTE_FACTOR`` values per point, relative to the point
            position) and ``vote_label_mask`` (1 for object points, 0
            otherwise). When ``output["seed_inds"]`` is 2-D these are indexed
            as (B, num_point, ...); when it is 1-D they are indexed along
            dim 0 (presumably batch-flattened — confirm against the caller).
        output: network results providing ``seed_pos`` (B, num_seed, 3),
            ``seed_votes`` (B, num_seed * vote_factor, 3) and ``seed_inds``
            (indices of the seeds into the input points).

    Returns:
        vote_loss: scalar Tensor, the per-seed vote distance averaged over
        seeds whose GT mask is set.

    Overall idea:
        If the seed point belongs to an object (votes_label_mask == 1),
        then we require it to vote for the object center.

        Each seed point may vote for multiple translations v1,v2,v3
        A seed point may also be in the boxes of multiple objects:
        o1,o2,o3 with corresponding GT votes c1,c2,c3

        Then the loss for this seed point is:
            min(d(v_i,c_j)) for i=1,2,3 and j=1,2,3
    """
    batch_size = output["seed_pos"].shape[0]
    num_seed = output["seed_pos"].shape[1]  # B,num_seed,3
    vote_xyz = output["seed_votes"]  # B,num_seed*vote_factor,3
    seed_inds = output["seed_inds"].long()  # B,num_seed in [0,num_points-1]
    vote_dims = 3 * GT_VOTE_FACTOR  # columns of vote_label per point

    # Gather GT votes for the seed points.
    # vote_label_mask: non-object point has mask = 0, object point has mask = 1
    # vote_label: offsets relative to the point, GT_VOTE_FACTOR * 3 values per point
    if seed_inds.dim() == 1:
        seed_gt_votes_mask = torch.gather(input["vote_label_mask"], 0, seed_inds).view((batch_size, -1))
        seed_inds_expand = seed_inds.unsqueeze(-1).repeat(1, vote_dims)
        seed_gt_votes = torch.gather(input["vote_label"], 0, seed_inds_expand)
        # Turn offsets into absolute targets: one copy of the seed position per GT vote.
        seed_gt_votes += output["seed_pos"].view((-1, 3)).repeat((1, GT_VOTE_FACTOR))
    else:
        seed_gt_votes_mask = torch.gather(input["vote_label_mask"], 1, seed_inds)
        seed_inds_expand = seed_inds.view(batch_size, num_seed, 1).repeat(1, 1, vote_dims)
        seed_gt_votes = torch.gather(input["vote_label"], 1, seed_inds_expand)
        # Turn offsets into absolute targets: one copy of the seed position per GT vote.
        seed_gt_votes += output["seed_pos"].repeat(1, 1, GT_VOTE_FACTOR)

    # Compute the min of min of distance
    # from B,num_seed*vote_factor,3 to B*num_seed,vote_factor,3
    vote_xyz_reshape = vote_xyz.view(batch_size * num_seed, -1, 3)
    # from B,num_seed,3*GT_VOTE_FACTOR to B*num_seed,GT_VOTE_FACTOR,3
    seed_gt_votes_reshape = seed_gt_votes.view(batch_size * num_seed, GT_VOTE_FACTOR, 3)

    # A predicted vote to nowhere is not penalized as long as there is a good vote near the GT vote.
    # dist2 is presumably, per GT vote, the L1 distance to its nearest predicted vote
    # (B*num_seed, GT_VOTE_FACTOR) — semantics come from nn_distance, confirm there.
    _, _, dist2, _ = nn_distance(vote_xyz_reshape, seed_gt_votes_reshape, l1=True)
    votes_dist, _ = torch.min(dist2, dim=1)  # (B*num_seed, GT_VOTE_FACTOR) -> (B*num_seed,)
    votes_dist = votes_dist.view(batch_size, num_seed)

    # Average only over seeds that belong to an object; eps guards an empty mask.
    mask = seed_gt_votes_mask.float()
    vote_loss = torch.sum(votes_dist * mask) / (torch.sum(mask) + 1e-6)
    return vote_loss
示例#2
0
def compute_box_and_sem_cls_loss(inputs, outputs, loss_params):
    """Compute 3D bounding box and semantic classification losses.

    Per-proposal targets are selected from the K2 ground-truth boxes with
    ``outputs.object_assignment`` (B, K). Every per-proposal term is averaged
    over proposals whose ``outputs["objectness_label"]`` is set. The second
    center term is instead averaged over ``inputs["box_label_mask"]``.

    Args:
        inputs: batch providing ``gt_center``, ``box_label_mask``,
            ``heading_class_label``, ``heading_residual_label``,
            ``size_class_label``, ``size_residual_label``, ``sem_cls_label``
            and ``pos`` (only its device is used).
        outputs: network results providing ``object_assignment``, ``center``,
            ``objectness_label``, ``heading_scores``,
            ``heading_residuals_normalized``, ``size_scores``,
            ``size_residuals_normalized`` and ``sem_cls_scores``.
        loss_params: provides ``num_heading_bin`` and ``mean_size_arr``
            (one mean size per size cluster).

    Returns:
        Tuple ``(center_loss, heading_class_loss,
        heading_residual_normalized_loss, size_class_loss,
        size_residual_normalized_loss, sem_cls_loss)``. Both size terms are
        the int ``0`` when ``mean_size_arr`` is empty.
    """
    eps = 1e-6
    device = inputs.pos.device
    num_heading_bin = loss_params.num_heading_bin
    mean_sizes = np.asarray(loss_params.mean_size_arr)
    num_size_cluster = len(mean_sizes)

    assignment = outputs.object_assignment
    batch_size = assignment.shape[0]
    objectness_mask = outputs["objectness_label"].float()
    objectness_total = torch.sum(objectness_mask) + eps
    ce = nn.CrossEntropyLoss(reduction="none")

    def masked_mean(per_proposal):
        # Average a (B, K) loss over the proposals flagged by the objectness mask.
        return torch.sum(per_proposal * objectness_mask) / objectness_total

    def pick_assigned(key):
        # Select, for each proposal, the label of its assigned GT box: (B, K) from (B, K2).
        return torch.gather(inputs[key], 1, assignment)

    def one_hot(labels, depth):
        # (B, K) class indices -> (B, K, depth) float one-hot on the inputs' device.
        encoding = torch.zeros(batch_size, labels.shape[1], depth).to(device)
        encoding.scatter_(2, labels.unsqueeze(-1).long(), 1)
        return encoding

    # --- Center: predicted->GT term over objects plus GT->predicted term over valid boxes.
    pred_to_gt, _, gt_to_pred, _ = nn_distance(outputs["center"], inputs["gt_center"])
    box_mask = inputs["box_label_mask"]
    gt_side = torch.sum(gt_to_pred * box_mask) / (torch.sum(box_mask) + eps)
    center_loss = masked_mean(pred_to_gt) + gt_side

    # --- Heading: bin classification plus residual within the GT bin.
    heading_cls = pick_assigned("heading_class_label")
    heading_class_loss = masked_mean(ce(outputs["heading_scores"].transpose(2, 1), heading_cls.long()))

    bin_width = np.pi / num_heading_bin
    heading_residual_target = pick_assigned("heading_residual_label") / bin_width
    heading_one_hot = one_hot(heading_cls, num_heading_bin)
    predicted_heading_residual = torch.sum(outputs["heading_residuals_normalized"] * heading_one_hot, -1)
    heading_residual_normalized_loss = masked_mean(
        huber_loss(predicted_heading_residual - heading_residual_target, delta=1.0)
    )

    # --- Size: cluster classification plus residual normalized by the cluster mean size.
    size_cls = pick_assigned("size_class_label")
    if num_size_cluster == 0:
        size_class_loss = 0
        size_residual_normalized_loss = 0
    else:
        size_class_loss = masked_mean(ce(outputs["size_scores"].transpose(2, 1), size_cls.long()))

        per_axis_assignment = assignment.unsqueeze(-1).repeat(1, 1, 3)
        size_residual_target = torch.gather(inputs["size_residual_label"], 1, per_axis_assignment)  # (B,K,3)

        size_one_hot = one_hot(size_cls, num_size_cluster).unsqueeze(-1).repeat(1, 1, 1, 3)  # (B,K,C,3)
        predicted_size_residual = torch.sum(outputs["size_residuals_normalized"] * size_one_hot, 2)  # (B,K,3)

        mean_size_table = (
            torch.from_numpy(mean_sizes.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(device)
        )  # (1,1,C,3)
        assigned_mean_size = torch.sum(size_one_hot * mean_size_table, 2)  # (B,K,3)
        size_error = predicted_size_residual - size_residual_target / assigned_mean_size
        size_residual_normalized_loss = masked_mean(torch.mean(huber_loss(size_error, delta=1.0), -1))

    # --- Semantic class of the assigned GT box.
    sem_cls = pick_assigned("sem_cls_label")
    sem_cls_loss = masked_mean(ce(outputs["sem_cls_scores"].transpose(2, 1), sem_cls.long()))

    return (
        center_loss,
        heading_class_loss,
        heading_residual_normalized_loss,
        size_class_loss,
        size_residual_normalized_loss,
        sem_cls_loss,
    )