예제 #1
0
def predict(args):
    """Predict one referred bounding box per test sample and dump them to JSON.

    For every sample of the "test" dataloader, the proposal with the highest
    masked ``cluster_ref`` score is decoded into a box via ``DC.param2obb`` /
    ``get_3d_box``. The boxes are written, together with the ScanRefer
    identifiers of the sample, to ``pred.json`` under
    ``CONF.PATH.OUTPUT/<args.folder>``.

    Args:
        args: parsed options. This function itself reads ``args.no_nms`` and
            ``args.folder``; ``args`` is also forwarded to ``get_scanrefer``,
            ``get_dataloader`` and ``get_model``.
    """
    print("predict bounding boxes...")
    # constant
    DC = ScannetDatasetConfig()

    # init training dataset
    print("preparing data...")
    scanrefer, scene_list = get_scanrefer(args)

    # dataloader
    _, dataloader = get_dataloader(args, scanrefer, scene_list, "test", DC)

    # model
    model = get_model(args, DC)

    # config: options handed to parse_predictions; None disables the NMS mask
    POST_DICT = {
        "remove_empty_box": True,
        "use_3d_nms": True,
        "nms_iou": 0.25,
        "use_old_type_nms": False,
        "cls_nms": True,
        "per_class_proposal": True,
        "conf_thresh": 0.05,
        "dataset_config": DC
    } if not args.no_nms else None

    # predict
    print("predicting...")
    pred_bboxes = []
    for data_dict in tqdm(dataloader):
        for key in data_dict:
            data_dict[key] = data_dict[key].cuda()

        # feed -- pure inference, so do not record an autograd graph
        with torch.no_grad():
            data_dict = model(data_dict)
            _, data_dict = get_loss(
                data_dict=data_dict,
                config=DC,
                detection=False,
                reference=True
            )

        objectness_preds_batch = torch.argmax(data_dict['objectness_scores'], 2).long()

        if POST_DICT:
            # parse_predictions is called for its side effect of filling
            # data_dict['pred_mask']; its return value is not needed here
            _ = parse_predictions(data_dict, POST_DICT)
            nms_masks = torch.LongTensor(data_dict['pred_mask']).cuda()

            # construct valid mask: kept by NMS and predicted as an object
            pred_masks = (nms_masks * objectness_preds_batch == 1).float()
        else:
            # construct valid mask: predicted as an object
            pred_masks = (objectness_preds_batch == 1).float()

        pred_ref = torch.argmax(data_dict['cluster_ref'] * pred_masks, 1) # (B,)
        pred_center = data_dict['center'] # (B,K,3)
        pred_heading_class = torch.argmax(data_dict['heading_scores'], -1) # B,num_proposal
        pred_heading_residual = torch.gather(data_dict['heading_residuals'], 2, pred_heading_class.unsqueeze(-1)) # B,num_proposal,1
        pred_heading_residual = pred_heading_residual.squeeze(2) # B,num_proposal
        pred_size_class = torch.argmax(data_dict['size_scores'], -1) # B,num_proposal
        pred_size_residual = torch.gather(data_dict['size_residuals'], 2, pred_size_class.unsqueeze(-1).unsqueeze(-1).repeat(1,1,1,3)) # B,num_proposal,1,3
        pred_size_residual = pred_size_residual.squeeze(2) # B,num_proposal,3

        for i in range(pred_ref.shape[0]):
            # decode the selected proposal into a box
            pred_ref_idx = pred_ref[i]
            pred_obb = DC.param2obb(
                pred_center[i, pred_ref_idx, 0:3].detach().cpu().numpy(),
                pred_heading_class[i, pred_ref_idx].detach().cpu().numpy(),
                pred_heading_residual[i, pred_ref_idx].detach().cpu().numpy(),
                pred_size_class[i, pred_ref_idx].detach().cpu().numpy(),
                pred_size_residual[i, pred_ref_idx].detach().cpu().numpy()
            )
            pred_bbox = get_3d_box(pred_obb[3:6], pred_obb[6], pred_obb[0:3])

            # construct the multiple mask
            multiple = data_dict["unique_multiple"][i].item()

            # construct the others mask (category id 17 is flagged as "others")
            others = 1 if data_dict["object_cat"][i] == 17 else 0

            # store data
            scanrefer_idx = data_dict["scan_idx"][i].item()
            pred_data = {
                "scene_id": scanrefer[scanrefer_idx]["scene_id"],
                "object_id": scanrefer[scanrefer_idx]["object_id"],
                "ann_id": scanrefer[scanrefer_idx]["ann_id"],
                "bbox": pred_bbox.tolist(),
                "unique_multiple": multiple,
                "others": others
            }
            pred_bboxes.append(pred_data)

    # dump
    print("dumping...")
    pred_path = os.path.join(CONF.PATH.OUTPUT, args.folder, "pred.json")
    with open(pred_path, "w") as f:
        json.dump(pred_bboxes, f, indent=4)

    print("done!")
예제 #2
0
def feed_scene_cap(model,
                   device,
                   dataset,
                   dataloader,
                   phase,
                   folder,
                   use_tf=False,
                   is_eval=True,
                   max_len=CONF.TRAIN.MAX_DES_LEN,
                   save_interm=False,
                   min_iou=CONF.TRAIN.MIN_IOU_THRESHOLD):
    """Run the captioning model over a dataloader and collect decoded captions.

    A proposal contributes a caption only if it survives the NMS/objectness
    mask AND its IoU with the assigned ground-truth box exceeds ``min_iou``.

    Args:
        model: callable invoked as ``model(data_dict, use_tf=..., is_eval=...)``.
        device: forwarded to ``get_scene_cap_loss``.
            NOTE(review): batches are still moved with ``.cuda()`` rather
            than ``.to(device)`` -- confirm whether that is intended.
        dataset: provides ``weights``, ``scanrefer`` and ``vocabulary``.
        dataloader: iterable of batches (dicts of tensors).
        phase: not used inside this function.
        folder: sub-folder of ``CONF.PATH.OUTPUT`` for ``interm.json``.
        use_tf: forwarded to the model.
        is_eval: forwarded to the model.
        max_len: not used inside this function.
        save_interm: when true, per-object intermediates (box corners,
            caption, attention weights/context) are collected and written to
            ``interm.json``.
        min_iou: IoU threshold a detected box must exceed to be kept.

    Returns:
        dict mapping ``"scene_id|object_id|object_name"`` to a one-element
        list holding the decoded caption. A later proposal with the same key
        overwrites an earlier one.
    """
    # post-process config; identical for every batch, so build it once
    POST_DICT = {
        "remove_empty_box": True,
        "use_3d_nms": True,
        "nms_iou": 0.25,
        "use_old_type_nms": False,
        "cls_nms": True,
        "per_class_proposal": True,
        "conf_thresh": 0.05,
        "dataset_config": DC
    }

    candidates = {}
    intermediates = {}
    for data_dict in tqdm(dataloader):
        # move to cuda
        for key in data_dict:
            data_dict[key] = data_dict[key].cuda()

        with torch.no_grad():
            data_dict = model(data_dict, use_tf=use_tf, is_eval=is_eval)
            data_dict = get_scene_cap_loss(data_dict,
                                           device,
                                           DC,
                                           weights=dataset.weights,
                                           detection=True,
                                           caption=False)

        # unpack
        captions = data_dict["lang_cap"].argmax(
            -1)  # batch_size, num_proposals, max_len - 1
        dataset_ids = data_dict["dataset_idx"]
        batch_size, num_proposals, _ = captions.shape

        # nms mask (parse_predictions fills data_dict["pred_mask"])
        _ = parse_predictions(data_dict, POST_DICT)
        nms_masks = torch.LongTensor(data_dict["pred_mask"]).cuda()

        # objectness mask
        obj_masks = torch.argmax(data_dict["objectness_scores"], 2).long()

        # final mask
        nms_masks = nms_masks * obj_masks

        # pick out object ids of detected objects
        detected_object_ids = torch.gather(data_dict["scene_object_ids"], 1,
                                           data_dict["object_assignment"])

        # bbox corners
        assigned_target_bbox_corners = torch.gather(
            data_dict["gt_box_corner_label"], 1,
            data_dict["object_assignment"].view(
                batch_size, num_proposals, 1,
                1).repeat(1, 1, 8, 3))  # batch_size, num_proposals, 8, 3
        detected_bbox_corners = data_dict[
            "bbox_corner"]  # batch_size, num_proposals, 8, 3

        # IoU between each detected box and its assigned ground truth box
        ious = box3d_iou_batch_tensor(
            assigned_target_bbox_corners.view(
                -1, 8, 3),  # batch_size * num_proposals, 8, 3
            detected_bbox_corners.view(-1, 8,
                                       3)  # batch_size * num_proposals, 8, 3
        ).view(batch_size, num_proposals)

        # find good boxes (IoU > threshold)
        good_bbox_masks = ious > min_iou  # batch_size, num_proposals

        # dump generated captions
        for batch_id in range(batch_size):
            dataset_idx = dataset_ids[batch_id].item()
            scene_id = dataset.scanrefer[dataset_idx]["scene_id"]
            for prop_id in range(num_proposals):
                is_kept = (nms_masks[batch_id, prop_id] == 1
                           and good_bbox_masks[batch_id, prop_id] == 1)
                if not is_kept:
                    continue

                object_id = str(detected_object_ids[batch_id, prop_id].item())
                caption_decoded = decode_caption(
                    captions[batch_id, prop_id],
                    dataset.vocabulary["idx2word"])

                # NOTE(review): the handler below skips the proposal on ANY
                # KeyError raised in this body, not only a missing
                # scene/object in SCANREFER_ORGANIZED -- a missing
                # "topdown_attn"/"valid_masks" entry is swallowed as well.
                try:
                    ann_list = list(
                        SCANREFER_ORGANIZED[scene_id][object_id].keys())
                    object_name = SCANREFER_ORGANIZED[scene_id][object_id][
                        ann_list[0]]["object_name"]

                    # store
                    key = "{}|{}|{}".format(scene_id, object_id, object_name)
                    candidates[key] = [caption_decoded]

                    if save_interm:
                        scene_entry = intermediates.setdefault(scene_id, {})
                        entry = scene_entry.setdefault(object_id, {})

                        entry["object_name"] = object_name
                        entry["box_corner"] = detected_bbox_corners[
                            batch_id, prop_id].cpu().numpy().tolist()
                        entry["description"] = caption_decoded
                        entry["token"] = caption_decoded.split(" ")

                        # attention context
                        # NOTE only consider attention on objects
                        object_attn_weights = data_dict[
                            "topdown_attn"][:, :, :num_proposals]
                        valid_context_masks = data_dict[
                            "valid_masks"][:, :, :num_proposals]

                        cur_valid_context_masks = valid_context_masks[
                            batch_id, prop_id]  # num_proposals
                        cur_context_box_corners = detected_bbox_corners[
                            batch_id,
                            cur_valid_context_masks == 1]  # X, 8, 3
                        cur_object_attn_weights = object_attn_weights[
                            batch_id, prop_id,
                            cur_valid_context_masks == 1]  # X

                        entry["object_attn_weight"] = (
                            cur_object_attn_weights.cpu().numpy().T.tolist())
                        entry["object_attn_context"] = (
                            cur_context_box_corners.cpu().numpy().tolist())
                except KeyError:
                    continue

    # detected boxes
    if save_interm:
        print("saving intermediate results...")
        interm_path = os.path.join(CONF.PATH.OUTPUT, folder, "interm.json")
        with open(interm_path, "w") as f:
            json.dump(intermediates, f, indent=4)

    return candidates
예제 #3
0
def eval_detection(args):
    """Evaluate object detection and print AP metrics at IoU 0.25 and 0.5.

    Args:
        args: parsed options. This function itself reads
            ``args.eval_pretrained`` (selects ``CONF.PATH.PRETRAINED`` over
            ``CONF.PATH.OUTPUT`` as the model root); ``args`` is also
            forwarded to ``get_eval_data``, ``get_dataloader`` and
            ``get_model``.
    """
    print("evaluate detection...")
    # constant
    DC = ScannetDatasetConfig()

    # init training dataset
    print("preparing data...")
    # get eval data
    scanrefer_eval, eval_scene_list = get_eval_data(args)

    # get dataloader
    dataset, dataloader = get_dataloader(args, scanrefer_eval, eval_scene_list,
                                         DC)

    # model
    print("initializing...")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    root = CONF.PATH.PRETRAINED if args.eval_pretrained else CONF.PATH.OUTPUT
    model = get_model(args, dataset, device, root)

    # config
    POST_DICT = {
        "remove_empty_box": True,
        "use_3d_nms": True,
        "nms_iou": 0.25,
        "use_old_type_nms": False,
        "cls_nms": True,
        "per_class_proposal": True,
        "conf_thresh": 0.05,
        "dataset_config": DC
    }
    AP_IOU_THRESHOLDS = [0.25, 0.5]
    AP_CALCULATOR_LIST = [
        APCalculator(iou_thresh, DC.class2type)
        for iou_thresh in AP_IOU_THRESHOLDS
    ]

    for data in tqdm(dataloader):
        # move the batch to the device selected above, so that the CPU
        # fallback is honoured instead of unconditionally requiring CUDA
        for key in data:
            data[key] = data[key].to(device)

        # feed
        with torch.no_grad():
            data = model(data, use_tf=False, is_eval=True)
            data = get_scene_cap_loss(data,
                                      device,
                                      DC,
                                      weights=dataset.weights,
                                      detection=True,
                                      caption=False)

        batch_pred_map_cls = parse_predictions(data, POST_DICT)
        batch_gt_map_cls = parse_groundtruths(data, POST_DICT)
        for ap_calculator in AP_CALCULATOR_LIST:
            ap_calculator.step(batch_pred_map_cls, batch_gt_map_cls)

    # aggregate object detection results and report
    for iou_thresh, ap_calculator in zip(AP_IOU_THRESHOLDS,
                                         AP_CALCULATOR_LIST):
        print()
        print("-" * 10, "iou_thresh: %f" % (iou_thresh), "-" * 10)
        metrics_dict = ap_calculator.compute_metrics()
        for key in metrics_dict:
            print("eval %s: %f" % (key, metrics_dict[key]))
예제 #4
0
def eval_det(args):
    """Evaluate detection on the "val" split and print sem_acc plus AP metrics.

    Each batch is passed through the model, ``get_loss`` and ``get_eval``;
    the per-batch ``sem_acc`` is averaged over batches and one ``APCalculator``
    per IoU threshold (0.25, 0.5) accumulates the parsed predictions and
    ground truths.

    Args:
        args: parsed options, forwarded to ``get_scanrefer``,
            ``get_dataloader`` and ``get_model``.
    """
    print("evaluate detection...")
    dataset_config = ScannetDatasetConfig()

    print("preparing data...")
    scanrefer, scene_list = get_scanrefer(args)
    _, dataloader = get_dataloader(args, scanrefer, scene_list, "val",
                                   dataset_config)

    model = get_model(args, dataset_config)

    # post-processing options shared by get_eval and the parse_* helpers
    post_config = dict(
        remove_empty_box=True,
        use_3d_nms=True,
        nms_iou=0.25,
        use_old_type_nms=False,
        cls_nms=True,
        per_class_proposal=True,
        conf_thresh=0.05,
        dataset_config=dataset_config,
    )
    iou_thresholds = [0.25, 0.5]
    calculators = []
    for threshold in iou_thresholds:
        calculators.append(APCalculator(threshold, dataset_config.class2type))

    sem_accuracies = []
    for batch in tqdm(dataloader):
        for name in batch:
            batch[name] = batch[name].cuda()

        # forward pass and evaluation, no gradients required
        with torch.no_grad():
            batch = model(batch)
            _, batch = get_loss(data_dict=batch,
                                config=dataset_config,
                                detection=True,
                                reference=False)
            batch = get_eval(data_dict=batch,
                             config=dataset_config,
                             reference=False,
                             post_processing=post_config)

        sem_accuracies.append(batch["sem_acc"].item())

        pred_map_cls = parse_predictions(batch, post_config)
        gt_map_cls = parse_groundtruths(batch, post_config)
        for calculator in calculators:
            calculator.step(pred_map_cls, gt_map_cls)

    # report the aggregated results
    print("\nobject detection sem_acc: {}".format(np.mean(sem_accuracies)))
    for threshold, calculator in zip(iou_thresholds, calculators):
        print()
        print("-" * 10, "iou_thresh: %f" % (threshold), "-" * 10)
        metrics = calculator.compute_metrics()
        for metric_name in metrics:
            print("eval %s: %f" % (metric_name, metrics[metric_name]))
예제 #5
0
def get_loss(data_dict,
             config,
             reference=False,
             use_lang_classifier=False,
             use_max_iou=False,
             post_processing=None):
    """Compute the total loss and store losses/statistics in ``data_dict``.

    Args:
        data_dict: dict of network outputs and labels; updated in place with
            the individual losses, masks and statistics computed here.
        config: dataset config instance (its ``param2obb`` is used to decode
            boxes when ``reference`` is true).
        reference: when true, the reference/language losses, ``ref_acc`` and
            the localization IoU metrics are computed as well; otherwise
            both losses are zero.
        use_lang_classifier: forwarded to ``compute_reference_loss``; also
            enables the ``lang_acc`` statistic.
        use_max_iou: forwarded to ``compute_reference_loss``; also selects
            the weight of ``ref_loss`` in the total (0.1 if true, else 0.01).
        post_processing: config dict for ``parse_predictions``; when given,
            its NMS mask is combined with the predicted objectness mask.

    Returns:
        loss: pytorch scalar tensor
        data_dict: dict
    """

    # Vote loss
    vote_loss = compute_vote_loss(data_dict)
    data_dict['vote_loss'] = vote_loss

    # Obj loss
    objectness_loss, objectness_label, objectness_mask, object_assignment = compute_objectness_loss(
        data_dict)
    data_dict['objectness_loss'] = objectness_loss
    data_dict['objectness_label'] = objectness_label
    data_dict['objectness_mask'] = objectness_mask
    data_dict['object_assignment'] = object_assignment
    total_num_proposal = objectness_label.shape[0] * objectness_label.shape[1]
    data_dict['pos_ratio'] = torch.sum(
        objectness_label.float().cuda()) / float(total_num_proposal)
    data_dict['neg_ratio'] = torch.sum(objectness_mask.float()) / float(
        total_num_proposal) - data_dict['pos_ratio']

    # Box loss and sem cls loss
    center_loss, heading_cls_loss, heading_reg_loss, size_cls_loss, size_reg_loss, sem_cls_loss = compute_box_and_sem_cls_loss(
        data_dict, config)
    data_dict['center_loss'] = center_loss
    data_dict['heading_cls_loss'] = heading_cls_loss
    data_dict['heading_reg_loss'] = heading_reg_loss
    data_dict['size_cls_loss'] = size_cls_loss
    data_dict['size_reg_loss'] = size_reg_loss
    data_dict['sem_cls_loss'] = sem_cls_loss
    box_loss = center_loss + 0.1 * heading_cls_loss + heading_reg_loss + 0.1 * size_cls_loss + size_reg_loss
    data_dict['box_loss'] = box_loss

    if reference:
        # Reference loss
        ref_loss, lang_loss, cluster_preds_scores, cluster_labels = compute_reference_loss(
            data_dict, config, use_lang_classifier, use_max_iou)
        data_dict["ref_loss"] = ref_loss
        data_dict["lang_loss"] = lang_loss

        objectness_preds_batch = torch.argmax(data_dict['objectness_scores'],
                                              2).long()
        objectness_labels_batch = objectness_label.long()

        # construct valid masks
        if post_processing:
            # parse_predictions fills data_dict['pred_mask'] (NMS mask)
            _ = parse_predictions(data_dict, post_processing)
            nms_masks = torch.LongTensor(data_dict['pred_mask']).cuda()
            pred_masks = (nms_masks * objectness_preds_batch == 1).float()
        else:
            pred_masks = (objectness_preds_batch == 1).float()
        label_masks = (objectness_labels_batch == 1).float()

        data_dict["pred_mask"] = pred_masks
        data_dict["label_mask"] = label_masks

        # one-hot encode the best-scoring valid proposal of every sample
        cluster_preds = torch.argmax(cluster_preds_scores * pred_masks,
                                     1).long().unsqueeze(1).repeat(
                                         1, pred_masks.shape[1])
        preds = torch.zeros(pred_masks.shape).cuda()
        preds = preds.scatter_(1, cluster_preds, 1)
        cluster_preds = preds
        cluster_labels = cluster_labels.float()
        cluster_labels *= label_masks

        # compute classification scores
        corrects = torch.sum((cluster_preds == 1) * (cluster_labels == 1),
                             dim=1).float()
        labels = torch.ones(corrects.shape[0]).cuda()

        ref_acc = corrects / (labels + 1e-8)

        # store
        data_dict["ref_acc"] = ref_acc.cpu().numpy().tolist()

        # compute localization metrics
        pred_ref = torch.argmax(
            data_dict['cluster_ref'] * data_dict['pred_mask'],
            1).detach().cpu().numpy()  # (B,)
        pred_center = data_dict['center'].detach().cpu().numpy()  # (B,K,3)
        pred_heading_class = torch.argmax(data_dict['heading_scores'],
                                          -1)  # B,num_proposal
        pred_heading_residual = torch.gather(
            data_dict['heading_residuals'], 2,
            pred_heading_class.unsqueeze(-1))  # B,num_proposal,1
        pred_heading_class = pred_heading_class.detach().cpu().numpy(
        )  # B,num_proposal
        pred_heading_residual = pred_heading_residual.squeeze(
            2).detach().cpu().numpy()  # B,num_proposal
        pred_size_class = torch.argmax(data_dict['size_scores'],
                                       -1)  # B,num_proposal
        pred_size_residual = torch.gather(
            data_dict['size_residuals'], 2,
            pred_size_class.unsqueeze(-1).unsqueeze(-1).repeat(
                1, 1, 1, 3))  # B,num_proposal,1,3
        pred_size_class = pred_size_class.detach().cpu().numpy()
        pred_size_residual = pred_size_residual.squeeze(
            2).detach().cpu().numpy()  # B,num_proposal,3

        gt_ref = torch.argmax(data_dict["ref_box_label"],
                              1).detach().cpu().numpy()
        gt_center = data_dict['center_label'].cpu().numpy(
        )  # (B,MAX_NUM_OBJ,3)
        gt_heading_class = data_dict['heading_class_label'].cpu().numpy(
        )  # B,K2
        gt_heading_residual = data_dict['heading_residual_label'].cpu().numpy(
        )  # B,K2
        gt_size_class = data_dict['size_class_label'].cpu().numpy()  # B,K2
        gt_size_residual = data_dict['size_residual_label'].cpu().numpy(
        )  # B,K2,3

        ious = []
        multiple = []
        for i in range(pred_ref.shape[0]):
            # compute the iou
            pred_ref_idx, gt_ref_idx = pred_ref[i], gt_ref[i]
            pred_obb = config.param2obb(pred_center[i, pred_ref_idx, 0:3],
                                        pred_heading_class[i, pred_ref_idx],
                                        pred_heading_residual[i, pred_ref_idx],
                                        pred_size_class[i, pred_ref_idx],
                                        pred_size_residual[i, pred_ref_idx])
            gt_obb = config.param2obb(gt_center[i, gt_ref_idx, 0:3],
                                      gt_heading_class[i, gt_ref_idx],
                                      gt_heading_residual[i, gt_ref_idx],
                                      gt_size_class[i, gt_ref_idx],
                                      gt_size_residual[i, gt_ref_idx])
            pred_bbox = get_3d_box(pred_obb[3:6], pred_obb[6], pred_obb[0:3])
            gt_bbox = get_3d_box(gt_obb[3:6], gt_obb[6], gt_obb[0:3])
            iou, _ = box3d_iou(pred_bbox, gt_bbox)
            ious.append(iou)

            # construct the multiple mask
            num_bbox = data_dict["num_bbox"][i]
            # Indexing yields a view: clone before the in-place shift so the
            # labels stored in data_dict are not silently modified.
            sem_cls_label = data_dict["sem_cls_label"][i].clone()
            # shift the entries past num_bbox so they differ from their
            # stored value when compared against object_cat
            sem_cls_label[num_bbox:] -= 1
            num_choices = torch.sum(
                data_dict["object_cat"][i] == sem_cls_label)
            multiple.append(1 if num_choices > 1 else 0)

        # store
        ious_arr = np.array(ious)
        data_dict["ref_iou"] = ious
        data_dict["ref_iou_rate_0.25"] = ious_arr[
            ious_arr >= 0.25].shape[0] / ious_arr.shape[0]
        data_dict["ref_iou_rate_0.5"] = ious_arr[
            ious_arr >= 0.5].shape[0] / ious_arr.shape[0]
        data_dict["ref_multiple_mask"] = multiple
    else:
        ref_loss = torch.zeros(1)[0].cuda()
        lang_loss = torch.zeros(1)[0].cuda()

    # Final loss function
    if use_max_iou:
        loss = vote_loss + 0.5 * objectness_loss + box_loss + 0.1 * sem_cls_loss + 0.1 * ref_loss + lang_loss
    else:
        loss = vote_loss + 0.5 * objectness_loss + box_loss + 0.1 * sem_cls_loss + 0.01 * ref_loss + lang_loss

    loss *= 10  # amplify

    data_dict['loss'] = loss

    # --------------------------------------------
    # Some other statistics
    obj_pred_val = torch.argmax(data_dict['objectness_scores'], 2)  # B,K
    obj_acc = torch.sum((obj_pred_val == objectness_label.long()).float() *
                        objectness_mask) / (torch.sum(objectness_mask) + 1e-6)
    data_dict['obj_acc'] = obj_acc
    # precision, recall, f1
    true_positives = torch.sum((obj_pred_val == 1) * (objectness_label == 1),
                               dim=1).float()
    num_pred_pos = torch.sum(obj_pred_val == 1, dim=1).float()
    num_label_pos = torch.sum(objectness_label == 1, dim=1).float()
    # precision = TP / predicted positives, recall = TP / labelled positives
    precisions = true_positives / (num_pred_pos + 1e-8)
    recalls = true_positives / (num_label_pos + 1e-8)
    f1s = 2 * precisions * recalls / (precisions + recalls + 1e-8)
    data_dict["objectness_precision"] = precisions.cpu().numpy().tolist()
    data_dict["objectness_recall"] = recalls.cpu().numpy().tolist()
    data_dict["objectness_f1"] = f1s.cpu().numpy().tolist()
    # lang
    if use_lang_classifier:
        data_dict["lang_acc"] = (torch.argmax(
            data_dict['lang_scores'],
            1) == data_dict["object_cat"]).float().mean()
    else:
        data_dict["lang_acc"] = torch.zeros(1)[0].cuda()

    return loss, data_dict
예제 #6
0
def get_eval(data_dict,
             config,
             reference,
             use_lang_classifier=False,
             use_oracle=False,
             use_cat_rand=False,
             use_best=False,
             post_processing=None):
    """Compute evaluation metrics and store them in ``data_dict``.

    No loss is computed here. The function stores, among others,
    ``ref_acc``, ``ref_iou``, ``ref_iou_rate_0.25``, ``ref_iou_rate_0.5``,
    ``ref_multiple_mask``, ``ref_others_mask``, ``pred_bboxes``,
    ``gt_bboxes``, ``obj_acc``, ``sem_acc`` and ``lang_acc``.

    Args:
        data_dict: dict of network outputs and labels; updated in place.
        config: dataset config instance (its ``param2obb`` decodes boxes).
        reference: flag (False/True); together with ``use_lang_classifier``
            it enables the ``lang_acc`` statistic.
        use_lang_classifier: see ``reference``.
        use_oracle: take box parameters from the ground truth assigned to
            each proposal instead of from the network predictions.
        use_cat_rand: pick a random proposal among those whose assigned
            ground-truth class equals ``object_cat``.
        use_best: replace ``cluster_ref`` by ``cluster_labels`` before the
            referred proposal is selected.
        post_processing: config dict for ``parse_predictions``; when given,
            its NMS mask is combined with the predicted objectness mask.

    Returns:
        data_dict: dict
    """

    # The unpacked values are not used below; the statement only requires
    # data_dict["lang_feat"] to exist and to be 3-dimensional.
    batch_size, num_words, _ = data_dict["lang_feat"].shape

    objectness_preds_batch = torch.argmax(data_dict['objectness_scores'],
                                          2).long()
    objectness_labels_batch = data_dict['objectness_label'].long()

    # construct valid masks
    if post_processing:
        # parse_predictions fills data_dict['pred_mask'] (NMS mask)
        _ = parse_predictions(data_dict, post_processing)
        nms_masks = torch.LongTensor(data_dict['pred_mask']).cuda()
        pred_masks = (nms_masks * objectness_preds_batch == 1).float()
    else:
        pred_masks = (objectness_preds_batch == 1).float()
    label_masks = (objectness_labels_batch == 1).float()

    # one-hot encode the best-scoring valid proposal of every sample
    cluster_preds = torch.argmax(data_dict["cluster_ref"] * pred_masks,
                                 1).long().unsqueeze(1).repeat(
                                     1, pred_masks.shape[1])
    preds = torch.zeros(pred_masks.shape).cuda()
    preds = preds.scatter_(1, cluster_preds, 1)
    cluster_preds = preds
    cluster_labels = data_dict["cluster_labels"].float()
    # NOTE(review): in-place multiply; if data_dict["cluster_labels"] is
    # already a float tensor this presumably modifies the stored tensor too.
    cluster_labels *= label_masks

    # compute classification scores
    corrects = torch.sum((cluster_preds == 1) * (cluster_labels == 1),
                         dim=1).float()
    labels = torch.ones(corrects.shape[0]).cuda()
    ref_acc = corrects / (labels + 1e-8)

    # store
    data_dict["ref_acc"] = ref_acc.cpu().numpy().tolist()

    # compute localization metrics
    if use_best:
        # NOTE(review): this `if` is not chained to the one below, so unless
        # use_cat_rand is set the final `else` recomputes pred_ref from this
        # label-valued cluster_ref multiplied by pred_masks. Kept as is --
        # confirm whether `elif` was intended.
        pred_ref = torch.argmax(data_dict["cluster_labels"], 1)  # (B,)
        # store the calibrated predictions and masks
        data_dict['cluster_ref'] = data_dict["cluster_labels"]
    if use_cat_rand:
        cluster_preds = torch.zeros(cluster_labels.shape).cuda()
        for i in range(cluster_preds.shape[0]):
            num_bbox = data_dict["num_bbox"][i]
            # Indexing yields a view: clone before the in-place shift so that
            # data_dict["sem_cls_label"] (gathered again for sem_acc below)
            # is not silently modified.
            sem_cls_label = data_dict["sem_cls_label"][i].clone()
            sem_cls_label[num_bbox:] -= 1
            candidate_masks = torch.gather(
                sem_cls_label == data_dict["object_cat"][i], 0,
                data_dict["object_assignment"][i])
            # build the index range on the mask's device so that boolean
            # indexing does not mix CPU and GPU tensors
            candidates = torch.arange(
                cluster_labels.shape[1],
                device=candidate_masks.device)[candidate_masks]
            try:
                chosen_idx = torch.randperm(candidates.shape[0])[0]
                chosen_candidate = candidates[chosen_idx]
                cluster_preds[i, chosen_candidate] = 1
            except IndexError:
                # no candidate: indexing the empty permutation fails and the
                # assignment below, with empty `candidates`, sets nothing
                cluster_preds[i, candidates] = 1

        pred_ref = torch.argmax(cluster_preds, 1)  # (B,)
        # store the calibrated predictions and masks
        data_dict['cluster_ref'] = cluster_preds
    else:
        pred_ref = torch.argmax(data_dict['cluster_ref'] * pred_masks,
                                1)  # (B,)
        # store the calibrated predictions and masks
        data_dict['cluster_ref'] = data_dict['cluster_ref'] * pred_masks

    if use_oracle:
        pred_center = data_dict['center_label']  # (B,MAX_NUM_OBJ,3)
        pred_heading_class = data_dict['heading_class_label']  # B,K2
        pred_heading_residual = data_dict['heading_residual_label']  # B,K2
        pred_size_class = data_dict['size_class_label']  # B,K2
        pred_size_residual = data_dict['size_residual_label']  # B,K2,3

        # assign: pick, per proposal, the parameters of its assigned GT box
        pred_center = torch.gather(
            pred_center, 1,
            data_dict["object_assignment"].unsqueeze(2).repeat(1, 1, 3))
        pred_heading_class = torch.gather(pred_heading_class, 1,
                                          data_dict["object_assignment"])
        pred_heading_residual = torch.gather(
            pred_heading_residual, 1,
            data_dict["object_assignment"]).unsqueeze(-1)
        pred_size_class = torch.gather(pred_size_class, 1,
                                       data_dict["object_assignment"])
        pred_size_residual = torch.gather(
            pred_size_residual, 1,
            data_dict["object_assignment"].unsqueeze(2).repeat(1, 1, 3))
    else:
        pred_center = data_dict['center']  # (B,K,3)
        pred_heading_class = torch.argmax(data_dict['heading_scores'],
                                          -1)  # B,num_proposal
        pred_heading_residual = torch.gather(
            data_dict['heading_residuals'], 2,
            pred_heading_class.unsqueeze(-1))  # B,num_proposal,1
        pred_heading_residual = pred_heading_residual.squeeze(
            2)  # B,num_proposal
        pred_size_class = torch.argmax(data_dict['size_scores'],
                                       -1)  # B,num_proposal
        pred_size_residual = torch.gather(
            data_dict['size_residuals'], 2,
            pred_size_class.unsqueeze(-1).unsqueeze(-1).repeat(
                1, 1, 1, 3))  # B,num_proposal,1,3
        pred_size_residual = pred_size_residual.squeeze(2)  # B,num_proposal,3

    # store
    data_dict["pred_mask"] = pred_masks
    data_dict["label_mask"] = label_masks
    data_dict['pred_center'] = pred_center
    data_dict['pred_heading_class'] = pred_heading_class
    data_dict['pred_heading_residual'] = pred_heading_residual
    data_dict['pred_size_class'] = pred_size_class
    data_dict['pred_size_residual'] = pred_size_residual

    gt_ref = torch.argmax(data_dict["ref_box_label"], 1)
    gt_center = data_dict['center_label']  # (B,MAX_NUM_OBJ,3)
    gt_heading_class = data_dict['heading_class_label']  # B,K2
    gt_heading_residual = data_dict['heading_residual_label']  # B,K2
    gt_size_class = data_dict['size_class_label']  # B,K2
    gt_size_residual = data_dict['size_residual_label']  # B,K2,3

    ious = []
    multiple = []
    others = []
    pred_bboxes = []
    gt_bboxes = []
    for i in range(pred_ref.shape[0]):
        # compute the iou
        pred_ref_idx, gt_ref_idx = pred_ref[i], gt_ref[i]
        pred_obb = config.param2obb(
            pred_center[i, pred_ref_idx, 0:3].detach().cpu().numpy(),
            pred_heading_class[i, pred_ref_idx].detach().cpu().numpy(),
            pred_heading_residual[i, pred_ref_idx].detach().cpu().numpy(),
            pred_size_class[i, pred_ref_idx].detach().cpu().numpy(),
            pred_size_residual[i, pred_ref_idx].detach().cpu().numpy())
        gt_obb = config.param2obb(
            gt_center[i, gt_ref_idx, 0:3].detach().cpu().numpy(),
            gt_heading_class[i, gt_ref_idx].detach().cpu().numpy(),
            gt_heading_residual[i, gt_ref_idx].detach().cpu().numpy(),
            gt_size_class[i, gt_ref_idx].detach().cpu().numpy(),
            gt_size_residual[i, gt_ref_idx].detach().cpu().numpy())
        pred_bbox = get_3d_box(pred_obb[3:6], pred_obb[6], pred_obb[0:3])
        gt_bbox = get_3d_box(gt_obb[3:6], gt_obb[6], gt_obb[0:3])
        iou = eval_ref_one_sample(pred_bbox, gt_bbox)
        ious.append(iou)

        # NOTE: get_3d_box() will return problematic bboxes
        pred_bbox = construct_bbox_corners(pred_obb[0:3], pred_obb[3:6])
        gt_bbox = construct_bbox_corners(gt_obb[0:3], gt_obb[3:6])
        pred_bboxes.append(pred_bbox)
        gt_bboxes.append(gt_bbox)

        # construct the multiple mask
        multiple.append(data_dict["unique_multiple"][i].item())

        # construct the others mask (category id 17 is flagged as "others")
        flag = 1 if data_dict["object_cat"][i] == 17 else 0
        others.append(flag)

    # lang
    if reference and use_lang_classifier:
        data_dict["lang_acc"] = (torch.argmax(
            data_dict['lang_scores'],
            1) == data_dict["object_cat"]).float().mean()
    else:
        data_dict["lang_acc"] = torch.zeros(1)[0].cuda()

    # store
    ious_arr = np.array(ious)
    data_dict["ref_iou"] = ious
    data_dict["ref_iou_rate_0.25"] = ious_arr[
        ious_arr >= 0.25].shape[0] / ious_arr.shape[0]
    data_dict["ref_iou_rate_0.5"] = ious_arr[
        ious_arr >= 0.5].shape[0] / ious_arr.shape[0]
    data_dict["ref_multiple_mask"] = multiple
    data_dict["ref_others_mask"] = others
    data_dict["pred_bboxes"] = pred_bboxes
    data_dict["gt_bboxes"] = gt_bboxes

    # --------------------------------------------
    # Some other statistics
    obj_pred_val = torch.argmax(data_dict['objectness_scores'], 2)  # B,K
    obj_acc = torch.sum(
        (obj_pred_val == data_dict['objectness_label'].long()).float() *
        data_dict['objectness_mask']) / (
            torch.sum(data_dict['objectness_mask']) + 1e-6)
    data_dict['obj_acc'] = obj_acc
    # detection semantic classification
    sem_cls_label = torch.gather(
        data_dict['sem_cls_label'], 1,
        data_dict['object_assignment'])  # select (B,K) from (B,K2)
    sem_cls_pred = data_dict['sem_cls_scores'].argmax(-1)  # (B,K)
    sem_match = (sem_cls_label == sem_cls_pred).float()
    # pred_mask holds 0/1 values, so clamping the count to >= 1 leaves the
    # ratio unchanged whenever a proposal is valid and yields 0 instead of
    # NaN (0/0) when none is
    num_valid = torch.clamp(data_dict["pred_mask"].sum(), min=1)
    data_dict["sem_acc"] = (sem_match *
                            data_dict["pred_mask"]).sum() / num_valid

    return data_dict
예제 #7
0
def get_eval(data_dict,
             config,
             reference,
             use_lang_classifier=False,
             use_oracle=False,
             use_cat_rand=False,
             use_best=False,
             post_processing=None):
    """Compute referring-expression evaluation metrics and store them in data_dict.

    For every sample in the batch the highest-scoring proposal of its scene
    is selected and scored by the point-level IoU between the proposal's
    member points and the points of the target instance.

    Args:
        data_dict: dict holding the network outputs and labels. Keys read:
            "cluster_ref" and "cluster_labels" (both 2-D, batch x proposals),
            "instance_labels", "object_id", "proposals_offset",
            "proposal_batch_ids", "proposals_idx" (rows of
            [proposal id, point index]) and, when the language classifier is
            evaluated, "lang_scores" and "object_cat".
        config: dataset config instance (not read by this function).
        reference: flag (False/True); together with use_lang_classifier it
            enables the language classification accuracy.
        use_lang_classifier: compute "lang_acc" from "lang_scores".
        use_oracle: accepted for interface compatibility; the oracle path
            is currently disabled and this flag has no effect.
        use_cat_rand: accepted for interface compatibility; the
            category-random path is currently disabled and this flag has
            no effect.
        use_best: overwrite data_dict["cluster_ref"] with
            data_dict["cluster_labels"].
        post_processing: config dict; when truthy, parse_predictions is run
            on data_dict first.

    Returns:
        data_dict: the same dict, extended with "ref_acc" (list of 0./1. per
        sample), "lang_acc" (scalar tensor), "ref_iou" (list of 1-element
        tensors), "ref_iou_rate_0.25" and "ref_iou_rate_0.5" (floats).
    """
    # Identical placement to the former unconditional .cuda() calls whenever
    # CUDA is available; falls back to the CPU so evaluation also runs
    # without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if post_processing:
        # Only the in-place effect on data_dict is wanted here (presumably it
        # fills data_dict['pred_mask'] - TODO confirm); no mask is applied to
        # the predictions yet.
        parse_predictions(data_dict, post_processing)

    cluster_preds = data_dict["cluster_ref"]  # (B, num_proposals)
    cluster_labels = data_dict["cluster_labels"].float()

    # reference accuracy: does the top-scoring proposal match the label?
    chosen = cluster_preds.argmax(dim=1).to(device)
    wanted = cluster_labels.argmax(dim=1).to(device)
    data_dict["ref_acc"] = (chosen == wanted).float().cpu().numpy().tolist()

    if use_best:
        # store the calibrated predictions
        # NOTE(review): the IoU below still ranks proposals with the scores
        # read before this overwrite, so use_best does not change the
        # reported IoU - confirm whether that is intended.
        data_dict['cluster_ref'] = data_dict["cluster_labels"]

    # TODO: use_cat_rand and use_oracle relied on bounding-box outputs and
    #       are disabled for now; a prediction mask (post_processing /
    #       objectness) is not applied to cluster_ref either.

    ### More Info (incl. comments) ###
    # in compute_reference_loss in loss_helper.py (same process)
    gt_instances = data_dict['instance_labels']  # one label per point
    target_inst_id = data_dict['object_id']  # (B)
    preds_offsets = data_dict['proposals_offset']
    proposal_batch_ids = data_dict['proposal_batch_ids']
    preds_instances = data_dict['proposals_idx']  # rows: [proposal, point]
    batch_size = cluster_preds.shape[0]
    total_num_proposals = len(preds_offsets) - 1

    ious = []
    for i in range(batch_size):
        # NOTE(review): points are matched over the whole batch, which is
        # only correct if instance ids are unique across scenes - confirm.
        target_points = torch.nonzero(
            gt_instances == target_inst_id[i]).reshape(-1)

        scene_mask = proposal_batch_ids == i
        scene_proposal_ids = torch.nonzero(scene_mask).reshape(-1).tolist()
        iou = _scene_proposal_ious(scene_proposal_ids, preds_offsets,
                                   preds_instances, target_points,
                                   total_num_proposals)

        scene_iou = iou[scene_mask.cpu()]
        scene_scores = cluster_preds[i][:scene_iou.shape[0]]
        if scene_scores.shape[0] > 0:
            best = int(torch.argmax(scene_scores))
            ious.append(scene_iou[best].unsqueeze(0))
        else:
            # A scene without proposals scores zero. This must be a tensor
            # (not a Python int) so that torch.cat below accepts it.
            ious.append(torch.zeros(1))

    # lang
    if reference and use_lang_classifier:
        data_dict["lang_acc"] = (torch.argmax(
            data_dict['lang_scores'],
            1) == data_dict["object_cat"]).float().mean()
    else:
        data_dict["lang_acc"] = torch.zeros(1)[0].to(device)

    # store
    data_dict["ref_iou"] = ious
    iou_values = torch.cat(ious).numpy()
    num_samples = iou_values.shape[0]
    data_dict["ref_iou_rate_0.25"] = int(
        (iou_values >= 0.25).sum()) / num_samples
    data_dict["ref_iou_rate_0.5"] = int(
        (iou_values >= 0.5).sum()) / num_samples

    # TODO: we may include sem_acc (structure probably only has to be
    #       slightly changed)

    return data_dict


def _scene_proposal_ious(scene_proposal_ids, proposals_offset, proposals_idx,
                         target_points, total_num_proposals):
    """Return a CPU tensor with the point-level IoU of each listed proposal.

    Entries of proposals that are not listed stay at zero. The tensor is
    indexed by the proposal id stored in the first column of proposals_idx.
    """
    iou = torch.zeros(total_num_proposals)
    target_points = target_points.to(proposals_idx.device)
    # Every proposal of the scene is scored, including the last one.
    for k in scene_proposal_ids:
        # proposals_idx is unordered as a whole, but the rows of proposal k
        # are the window [offset[k], offset[k + 1]).
        members = proposals_idx[int(proposals_offset[k]):
                                int(proposals_offset[k + 1])]
        if members.shape[0] == 0:
            continue
        cluster_id = int(members[0, 0])
        member_points = members[:, 1].long()

        # A point index that occurs twice in the concatenation belongs to
        # both the proposal and the target instance.
        combined = torch.cat((member_points, target_points))
        _, counts = combined.unique(return_counts=True)
        intersection = int((counts > 1).sum())
        iou[cluster_id] = intersection / (combined.shape[0] - intersection)
    return iou