예제 #1
0
    def forward(self, example, ret_dict):
        """Fuse per-sample RGB features into an expanded sparse voxel tensor.

        ``input_dict['dummy_batch_list']`` maps each slot of the sparse batch
        to a sample index. For every slot, the following are gathered for
        that sample:

        - the RGB feature map;
        - the 'P2' / 'Trv2c' / 'rect' calibration entries;
        - the ``rgb_coor_refine`` record.

        The active voxel set is grown with ``simple_sp_tensor_expansion``, and
        the gathered features are fused onto it by ``fuse_rgb_to_voxel``
        (voxel size [0.1, 0.05, 0.05], offset [-3, -40, 0]).

        Args:
            example: batch dict; only ``example['calib']`` is read.
            ret_dict: outputs of earlier modules, read under
                ``self.voxel_in_key`` and ``self.rgb_in_key``.

        Returns:
            ``{self.out_key: merged_feature}``.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        calib_keys = ['P2', 'Trv2c', 'rect']

        input_dict = get_input(ret_dict, self.voxel_in_key)
        sp_in = input_dict['input_sp']  # spconv tensor
        rgb_features = get_input(ret_dict, self.rgb_in_key)  # torch tensor
        calib = example['calib']

        # sample index backing every slot of the sparse batch
        slot_to_sample = input_dict['dummy_batch_list']
        sample_ids = [slot_to_sample[slot] for slot in range(sp_in.batch_size)]

        gathered_rgb = torch.cat(
            [rgb_features[sid].unsqueeze(0) for sid in sample_ids], dim=0)

        # the per-slot '<key>_list' entries stay in the dict next to the
        # stacked '<key>' tensors
        calib_dict = {
            f'{key}_list': [
                torch.tensor(calib[key][sid], device=device).unsqueeze(0)
                for sid in sample_ids
            ]
            for key in calib_keys
        }
        for key in calib_keys:
            calib_dict[key] = torch.cat(calib_dict[f'{key}_list'], dim=0)

        refine_list = [input_dict['rgb_coor_refine'][sid] for sid in sample_ids]

        # grow the set of active voxels before fusing
        grown = simple_sp_tensor_expansion(sp_in,
                                           self.expansion['kernel'],
                                           self.expansion['padding'],
                                           computation_device=device)

        fused = fuse_rgb_to_voxel(grown, [0.1, 0.05, 0.05],
                                  [-3, -40, 0],
                                  gathered_rgb,
                                  calib_dict,
                                  pixel_refinement=refine_list)
        return {self.out_key: fused}
예제 #2
0
    def forward(self, example, ret_dict):
        """Run the sparse encoder/decoder over the tensor at ``self.in_key``.

        Levels ``1 .. len(self.channels) - 1`` each have an ``encoder_{l}``
        and a ``decoder_{l}``. Decoding starts at the deepest level. Every
        shallower level merges its encoder output with the decoder output one
        level deeper via ``concat_sp_tensors``. When ``self.feature_concat``
        is set, the merged tensor is passed through ``conv1x1_{l}``. The
        final result merges the stem with ``dec_1`` (plus ``conv1x1_0`` when
        concatenating).

        Returns:
            ``{self.out_key: feats}``. ``feats`` holds the 'stem', every
            'enc_{l}' / 'dec_{l}' intermediate and the final 'res'.
        """
        num_levels = len(self.channels)
        deepest = num_levels - 1
        feats = {'stem': self.stem(get_input(ret_dict, self.in_key))}

        # encoder path: each level consumes the previous level's output
        prev = feats['stem']
        for lvl in range(1, num_levels):
            prev = getattr(self, f'encoder_{lvl}')(prev)
            feats[f'enc_{lvl}'] = prev

        # decoder path, deepest level first, with skip connections
        for lvl in range(deepest, 0, -1):
            if lvl == deepest:
                merged = feats[f'enc_{lvl}']
            else:
                merged = concat_sp_tensors(feats[f'enc_{lvl}'],
                                           feats[f'dec_{lvl + 1}'],
                                           self.feature_concat)
                if self.feature_concat:
                    merged = getattr(self, f'conv1x1_{lvl}')(merged)
            feats[f'dec_{lvl}'] = getattr(self, f'decoder_{lvl}')(merged)

        # final skip connection from the stem
        out = concat_sp_tensors(feats['stem'], feats['dec_1'],
                                feature_concat=self.feature_concat)
        if self.feature_concat:
            out = getattr(self, 'conv1x1_0')(out)
        feats['res'] = out

        return {self.out_key: feats}
예제 #3
0
    def forward(self, example, ret_dict):
        """Run the five-stage multi-resolution network on ``self.in_key``.

        Stages 1-4 apply one ``stage_{s}_basic_{i}`` module per resolution
        stream and then the ``stage_{s}_mr`` module. ``stage_{s}_mr`` returns
        one more stream than it receives, except at stage 4, where only the
        first of its four outputs is kept. Stage 5 applies
        ``stage_5_basic_0`` followed by ``stage_5_basic_0_2``.

        Returns:
            ``{self.out_key: <stage 5 output>}``.
        """
        x = get_input(ret_dict, self.in_key)

        # stage 1: one stream in, two streams out
        s1_a, s1_b = self.stage_1_mr(self.stage_1_basic(x))

        # stage 2: refine two streams, exchange into three
        s2_a, s2_b, s2_c = self.stage_2_mr(self.stage_2_basic_0(s1_a),
                                           self.stage_2_basic_1(s1_b))

        # stage 3: refine three streams, exchange into four
        s3_a, s3_b, s3_c, s3_d = self.stage_3_mr(self.stage_3_basic_0(s2_a),
                                                 self.stage_3_basic_1(s2_b),
                                                 self.stage_3_basic_2(s2_c))

        # stage 4: refine four streams, keep only the first exchanged output
        s4_a, _, _, _ = self.stage_4_mr(self.stage_4_basic_0(s3_a),
                                        self.stage_4_basic_1(s3_b),
                                        self.stage_4_basic_2(s3_c),
                                        self.stage_4_basic_3(s3_d))

        # stage 5: two basic modules in sequence
        out = self.stage_5_basic_0_2(self.stage_5_basic_0(s4_a))
        return {self.out_key: out}
예제 #4
0
    def forward(self, example, ret_dict):
        """Normalize the RGB input, run the image backbone and fuse its stages.

        The input (B, C, H, W per the permute comments) is moved
        channels-last. ``kitti_mean`` is subtracted and the result is divided
        by ``kitti_std``; both are broadcast over the last axis, so they are
        presumably per-channel sequences of length C (TODO confirm). The
        image is then permuted back and passed through ``self.model``. For
        every stage in ``self.use_stages``, the stage result is fused with
        ``self.fuse_stage_result`` and projected by ``conv1x1_{stage}``. The
        projections are summed and passed through ``self.post_process``.

        Normalization and summation are done out of place. ``permute``
        returns a view sharing storage with the tensor held in ``ret_dict``,
        so in-place ``-=`` / ``/=`` would overwrite the caller's input image.

        Returns:
            ``{self.out_key: res_feature}``.
        """
        # get input rgb
        rgb = get_input(ret_dict, self.in_key)

        # normalize input rgb (out of place, see docstring)
        rgb = rgb.permute(0, 2, 3, 1)  # B, C, H, W -> B, H, W, C
        mean = torch.tensor(kitti_mean, dtype=rgb.dtype, device=rgb.device)
        std = torch.tensor(kitti_std, dtype=rgb.dtype, device=rgb.device)
        rgb = (rgb - mean) / std
        rgb = rgb.permute(0, 3, 1, 2)  # B, H, W, C -> B, C, H, W

        # get rgb features
        rgb_res_dict = self.model(rgb)

        # merge features: sum of the projected per-stage features
        res_feature = None
        for stage in self.use_stages:
            feature = self.fuse_stage_result(rgb_res_dict[stage])
            feature = getattr(self, f'conv1x1_{stage}')(feature)
            res_feature = feature if res_feature is None else res_feature + feature
        res_feature = self.post_process(res_feature)

        return {self.out_key: res_feature}
예제 #5
0
def generator(dataloader, net, confidence, input_key, pred_key, gt_key):
    """Yield per-sample input, predicted and ground-truth coordinates.

    For every batch from ``dataloader``, ``net`` is run without gradients,
    and the last element of its output is used as the result dict.

    - Predictions at ``pred_key`` are passed through a sigmoid and filtered
      with ``choose_coors(..., confidence)``.
    - Ground truth at ``gt_key`` may be a plain tensor of indices or an
      object with ``.indices``.
    - When ``gt_key`` is falsy, an empty (0, 4) int32 index tensor is used.

    All three coordinate sets are split per sample with ``group_coors``, using
    the batch size of the input tensor.

    Yields:
        ``(in_coors_list, out_coors_list, gt_coors_list)`` per batch.
    """
    for example in dataloader:
        with torch.no_grad():
            ret_dict = net(example)[-1]

        input_sp = get_input(ret_dict, input_key)
        pred_sp = get_input(ret_dict, pred_key)
        in_indices = input_sp.indices

        if not gt_key:
            gt_indices = torch.zeros(0, 4, dtype=torch.int32,
                                     device=in_indices.device)
        else:
            gt_item = get_input(ret_dict, gt_key)
            gt_indices = (gt_item if isinstance(gt_item, torch.Tensor)
                          else gt_item.indices)

        in_coors = indices_2_coors(in_indices)
        out_coors = choose_coors(indices_2_coors(pred_sp.indices),
                                 torch.sigmoid(pred_sp.features),
                                 confidence)
        gt_coors = indices_2_coors(gt_indices)

        n_batch = input_sp.batch_size
        yield (group_coors(in_coors, n_batch),
               group_coors(out_coors, n_batch),
               group_coors(gt_coors, n_batch))
예제 #6
0
    def forward(self, example, ret_dict):
        """Apply ``block_0`` .. ``block_{block_num - 1}`` then ``post_block``.

        The input is read from ``ret_dict`` under ``self.in_key``. The output
        is returned as ``{self.out_key: result}``.
        """
        out = get_input(ret_dict, self.in_key)
        for block in (getattr(self, f'block_{i}') for i in range(self.block_num)):
            out = block(out)
        return {self.out_key: self.post_block(out)}
예제 #7
0
    def loss(self, example, ff_ret_dict):
        """Focal loss plus projection-perspective-constraint (PPC) loss.

        Building the dummy voxel set:

        - Predicted and ground-truth voxel indices are merged into one sorted
          "dummy" set.
        - Dummy voxels present only in the ground truth get a prediction of
          1.0; ground-truth voxels get label 1.0.
        - Every dummy voxel is projected to image pixels.
        - For each voxel, the mean of its ``self.k`` smallest ``dist_chamfer``
          values to the same sample's ground-truth voxels is computed (0 for
          samples with no ground truth).

        Masking: voxels are sorted by (batch, u, v, distance). Positions
        ``head .. head + self.m - 1`` from the start of every (batch, u, v)
        pixel group are masked out of the focal term. A group smaller than
        ``m`` lets this spill into the following group. Ground-truth voxels
        are always kept.

        PPC term: an L1 loss on the masked entries. The targets are the
        labels, except that the first entry of each pixel group (after a
        re-sort by descending detached prediction) is set to the group's
        highest prediction. All targets of samples without ground truth
        are zeroed.

        Args:
            example: batch dict; ``example['calib']`` and
                ``example['voxel_dict']`` are read.
            ff_ret_dict: feed-forward results. Three entries are looked up:
                ``self.out_key`` (sparse prediction), ``self.pre_key``
                (preprocessor dict) and ``self.gt_key`` (ground-truth
                indices).

        Returns:
            ``(loss, loss_info)``. ``loss`` is
            ``lambda_F * loss_f + lambda_C * loss_ppc``; ``loss_info`` holds
            logging statistics.
        """
        # get device
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # get final output sp tensor of the network
        pred_sp = get_input(ff_ret_dict, self.out_key)

        # get preprocessor output dict
        pre_dict = get_input(ff_ret_dict, self.pre_key)

        # get predicted indices and features (sigmoid function applied)
        pred_indices = pred_sp.indices
        pred_features = torch.sigmoid(pred_sp.features)

        # get true ground truth indices
        label_indices = get_input(ff_ret_dict, self.gt_key)

        # dummy set = sorted union of predicted and ground-truth voxel indices;
        # voxels that appear only in the ground truth keep a feature of 1.0
        indices_cated = torch.cat((pred_indices, label_indices), dim=0)
        dummy_pred_indices, i_indices_index = torch.unique(indices_cated,
                                                           dim=0,
                                                           return_inverse=True)
        dummy_pred_features = torch.ones(dummy_pred_indices.shape[0],
                                         1,
                                         dtype=pred_features.dtype,
                                         device=pred_features.device)
        dummy_pred_features[
            i_indices_index[:pred_indices.shape[0]]] = pred_features

        # label features (same shape as dummy_pred_features): 1.0 at ground-
        # truth voxels. The union already contains every ground-truth row, so
        # the tail of the inverse index is exactly their position (identical
        # to a second torch.unique over (dummy_pred_indices, label_indices)).
        label_features = torch.zeros(dummy_pred_features.shape[0],
                                     1,
                                     device=dummy_pred_features.device)
        label_features[i_indices_index[pred_indices.shape[0]:]] = 1.0

        # project dummy_pred_indices to pixels
        coors = indices_to_coors(dummy_pred_indices,
                                 voxel_size=[0.1, 0.05, 0.05],
                                 offset=[-3.0, -40.0, 0.0])
        calib = example['calib']
        batch_size = pred_sp.batch_size
        dummy_batch_list = pre_dict['dummy_batch_list']
        dummy_calib_dict = {'P2_list': [], 'Trv2c_list': [], 'rect_list': []}
        for index in range(batch_size):
            for key in ['P2', 'Trv2c', 'rect']:
                dummy_calib_dict[f'{key}_list'].append(
                    torch.tensor(calib[key][dummy_batch_list[index]],
                                 device=device).unsqueeze(0))
        for key in ['P2', 'Trv2c', 'rect']:
            dummy_calib_dict[key] = torch.cat(dummy_calib_dict[f'{key}_list'],
                                              dim=0)

        # pixel_location is location of pixel corresponding the 3d points;
        # columns 1: are scaled by each sample's 'resize_scale', then rounded
        pixel_location = coors_to_pixels(coors,
                                         dummy_calib_dict,
                                         pixel_refinement=None,
                                         post_process='none')
        for b in range(pred_sp.batch_size):
            b_idx = pixel_location[:, 0] == b
            pixel_location[b_idx, 1:] *= pre_dict['rgb_coor_refine'][
                dummy_batch_list[b]]['resize_scale']
        pixel_location = pixel_location.round()

        # per dummy voxel: mean of its k smallest dist_chamfer values to the
        # ground-truth voxels of the same sample (stays 0 without ground truth)
        mean_dis = torch.zeros(pixel_location.shape[0],
                               1,
                               device=pixel_location.device)
        for batch in range(pred_sp.batch_size):
            dp_idx = dummy_pred_indices[:, 0] == batch
            shrinked_dummy_pred_indices = dummy_pred_indices[dp_idx, 1:]
            gt_idx = label_indices[:, 0] == batch
            shrinked_label_indices = label_indices[gt_idx, 1:]
            # if bg examples, continue
            if shrinked_label_indices.shape[0] == 0:
                continue

            dummy_dis_to_gt = dist_chamfer(
                torch.unsqueeze(shrinked_dummy_pred_indices, dim=0),
                torch.unsqueeze(shrinked_label_indices, dim=0))
            k = min(self.k, shrinked_label_indices.shape[0])
            k_nearest_dis, _ = dummy_dis_to_gt[0].topk(k,
                                                       dim=1,
                                                       largest=False,
                                                       sorted=True)
            batch_mean_dis = torch.mean(k_nearest_dis, dim=1, keepdim=True)
            mean_dis[dp_idx] = batch_mean_dis.float()

        # sort voxels by (batch, u, v, distance)
        pixel_w_dis = torch.cat((pixel_location, mean_dis.float()), dim=1)
        pixel_w_dis_np = pixel_w_dis.cpu().numpy()
        pixel_w_dis_sorted_index = torch.tensor(
            np.lexsort((pixel_w_dis_np[:, 3], pixel_w_dis_np[:, 2],
                        pixel_w_dis_np[:, 1], pixel_w_dis_np[:, 0]),
                       axis=0))
        pixel_w_dis_sorted = pixel_w_dis[pixel_w_dis_sorted_index]
        _, _, pixel_w_dis_sorted_unique_c = torch.unique(
            pixel_w_dis_sorted[:, :3],
            dim=0,
            return_inverse=True,
            return_counts=True)
        # start offset of every (b, u, v) group in the sorted order: exclusive
        # prefix sum of the group sizes (replaces an O(n^2) loop of partial
        # sums). Kept as a CPU long tensor so it can index tensors on any device.
        group_sizes = pixel_w_dis_sorted_unique_c.cpu()
        p_w_d_head = torch.cumsum(group_sizes, dim=0) - group_sizes

        dummy_pred_features = dummy_pred_features[pixel_w_dis_sorted_index]
        label_features = label_features[pixel_w_dis_sorted_index]
        # Mask of entries used by the focal term. The comparison makes it
        # torch.bool on PyTorch >= 1.2. A uint8 mask is wrong there because
        # ``~`` is a bitwise NOT (~1 == 254), so ``~flag`` below would select
        # every element. The mask lives on the labels' device so it can be
        # indexed by the masks derived from label_features.
        flag = torch.ones(label_features.shape[0] + self.m - 1,
                          1,
                          dtype=torch.uint8,
                          device=label_features.device) != 0
        for i in range(self.m):
            flag[p_w_d_head + i] = False
        flag = flag[:label_features.shape[0]]
        flag[label_features == 1.0] = True  # make sure gt always be labeled
        label_features_l = label_features[flag]
        dummy_pred_features_l = dummy_pred_features[flag]

        # calculate focal_loss
        focal_loss = FocalLoss()
        loss_f = focal_loss(dummy_pred_features_l, label_features_l)

        # calculate stat.
        dropped_f_label = (label_features.sum() -
                           label_features_l.sum()).cpu().item()

        set_p__p_gt = label_features == 0.0
        set_p_gt = (label_features == 1.0) & (dummy_pred_features != 1.0)
        set_gt__p_gt = dummy_pred_features == 1.0
        flag_p__p_gt = flag[set_p__p_gt].sum().item()
        flag_p_gt = flag[set_p_gt].sum().item()
        flag_gt__p_gt = flag[set_gt__p_gt].sum().item()
        total_p__p_gt = set_p__p_gt.sum().item()
        total_p_gt = set_p_gt.sum().item()
        total_gt__p_gt = set_gt__p_gt.sum().item()
        flag_test_dict = {
            'p__p_gt_total':
            total_p__p_gt,
            'p__p_gt_flag':
            flag_p__p_gt,
            'p__p_gt_ratio':
            0 if total_p__p_gt == 0 else flag_p__p_gt / total_p__p_gt,
            'p_gt_total':
            total_p_gt,
            'p_gt_flag':
            flag_p_gt,
            'p_gt_ratio':
            0 if total_p_gt == 0 else flag_p_gt / total_p_gt,
            'gt__p_gt_total':
            total_gt__p_gt,
            'gt__p_gt_flag':
            flag_gt__p_gt,
            'gt__p_gt_ratio':
            0 if total_gt__p_gt == 0 else flag_gt__p_gt / total_gt__p_gt,
        }

        gt_num = label_indices.shape[0]
        drop_gt_num = dummy_pred_indices.shape[0] - pred_indices.shape[0]
        pred_num = pred_indices.shape[0]

        if label_indices.shape[0] > 0:
            thres_acc_dict = {}
            thres_recall_dict = {}
            thres = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
            # Loop-invariant quantities, computed once. Entries exactly equal
            # to 1.0 are discounted. This is presumably meant to count the
            # GT-only voxels whose feature was set to 1.0 above; note that a
            # saturated float32 sigmoid can also reach 1.0.
            pred_l_cpu = dummy_pred_features_l.detach().cpu()
            label_l_cpu = label_features_l.cpu()
            num_dummy_labeled_gt = (pred_l_cpu == 1.0).sum().float()
            thre_total = pred_l_cpu.shape[0]
            thre_gt_idx = label_features_l == 1.0
            thre_recall_total = thre_gt_idx.sum().cpu()
            gt_pred_l_cpu = dummy_pred_features_l[thre_gt_idx].detach().cpu()
            for thre in thres:
                # acc.
                thre_pred = (pred_l_cpu > thre).float()
                thre_correct = ((thre_pred - label_l_cpu).abs() < 0.001).sum()
                thres_acc_dict[f'{thre}'] = (
                    (thre_correct - num_dummy_labeled_gt) /
                    (thre_total - num_dummy_labeled_gt)).item()
                # recall
                thre_recall_correct = (gt_pred_l_cpu > thre).sum()
                thres_recall_dict[f'{thre}'] = (
                    (thre_recall_correct - num_dummy_labeled_gt) /
                    (thre_recall_total - num_dummy_labeled_gt)).item()

        # Projection Perspective Constraint Loss loss
        # re-sort within each (b, u, v) group by descending detached prediction
        # so the group head holds the group's highest prediction
        label_features_ppc = label_features.clone()
        BUVF = torch.cat(
            (pixel_w_dis_sorted[:, :3], dummy_pred_features.detach()), dim=1)
        BUVF_np = BUVF.cpu().numpy()
        BUVF_sorted_idx = torch.tensor(
            np.lexsort(
                (-BUVF_np[:, 3], BUVF_np[:, 2], BUVF_np[:, 1], BUVF_np[:, 0]),
                axis=0))
        BUVF_sorted = BUVF[BUVF_sorted_idx]
        label_features_ppc = label_features_ppc[BUVF_sorted_idx]
        label_features_ppc[p_w_d_head, 0] = BUVF_sorted[p_w_d_head, 3]
        dummy_pred_features = dummy_pred_features[BUVF_sorted_idx]
        flag = flag[BUVF_sorted_idx]

        # deal with bg examples
        for b in range(pred_sp.batch_size):
            if len(example['voxel_dict'][pre_dict['dummy_batch_list'][b]]
                   ['gt']) == 0:
                b_bg_idx = BUVF_sorted[:, 0] == b
                label_features_ppc[b_bg_idx] = 0

        # calculate loss_ppc on the entries excluded from the focal term
        loss_ppc_L1 = torch.nn.L1Loss()
        loss_ppc = loss_ppc_L1(dummy_pred_features[~flag],
                               label_features_ppc[~flag])

        # total loss
        loss = self.lambda_F * loss_f + self.lambda_C * loss_ppc

        # information to be logged and displayed
        loss_info = {
            'gt_num': gt_num,
            'drop_gt_num': drop_gt_num,
            'pred_num': pred_num,
            'dropped_f_label': dropped_f_label,
            'flag_test': flag_test_dict,
            'loss': {
                'loss_total': loss,
                'loss_f': loss_f.item(),
                'loss_ppc': loss_ppc.item(),
            }
        }
        if label_indices.shape[0] > 0:
            loss_info['acc_f'] = thres_acc_dict
            loss_info['recall_f'] = thres_recall_dict

        return loss, loss_info
예제 #8
0
def create_inferenced_velodyne_data(dataset_cfg_path,
                                    dataset_section,
                                    model_cfg_path,
                                    model_path,
                                    pred_key,
                                    confidence,
                                    velodyne_path,
                                    output_path,
                                    batch_size=6,
                                    eval_flag=True):
    """Append network-predicted points to per-scene velodyne ``.bin`` files.

    The dataset class and collate function are loaded from
    ``dataset_cfg[dataset_section]``. ``LPCC_Net`` is built from
    ``model_cfg['MODEL']`` and weights are loaded from ``model_path``. The
    dataset is run without gradients.

    For every batch:

    - Voxels at ``pred_key`` whose sigmoid score exceeds ``confidence`` are
      kept.
    - They are converted to coordinates with voxel size [0.1, 0.05, 0.05] and
      offset [-3, -40, 0], reordered to columns [3, 2, 1].
    - A fourth column of zeros is appended.

    For every scene, the predicted points are prepended to the points read
    from ``velodyne_path/<scene>.bin``. The result is written to
    ``output_path/<scene>.bin`` with the dtype read from disk.

    Args:
        dataset_cfg_path: path of the dataset INI config.
        dataset_section: section of that config describing the dataset.
        model_cfg_path: path of the model INI config (``MODEL`` section).
        model_path: path of the saved ``state_dict``.
        pred_key: key of the prediction sparse tensor in the network output.
        confidence: score threshold applied after the sigmoid.
        velodyne_path: directory with the original ``.bin`` files.
        output_path: directory for the merged files (created if missing).
        batch_size: batch size, also used as the number of loader workers.
        eval_flag: put the network in eval mode when True.
    """
    # get configurations
    dataset_cfg = ConfigParser()
    model_cfg = ConfigParser()
    dataset_cfg.read(dataset_cfg_path)
    model_cfg.read(model_cfg_path)

    # prepare dataset
    section_cfg = dataset_cfg[dataset_section]
    dataset = get_class(section_cfg['class'])(section_cfg)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=batch_size,
        pin_memory=False,
        collate_fn=get_class(section_cfg['collate_fn']),
    )

    # prepare network model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = LPCC_Net(model_cfg['MODEL']).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)
    if eval_flag:
        net.eval()

    # generate predictions
    print('generating predictions ...', flush=True)
    pred_dict = defaultdict(list)
    for example in tqdm(dataloader):
        # network feed-forward
        with torch.no_grad():
            res = net(example)
        ret_dict = res[-1]

        # get output confidence and indices
        pred_sp = get_input(ret_dict, pred_key)
        out_features = torch.sigmoid(pred_sp.features)
        out_indices = pred_sp.indices
        filtered_indices = out_indices[out_features[:, 0] > confidence]

        # transform to points; column 0 of the indices is the batch slot
        coors = indices_to_coors(filtered_indices, [0.1, 0.05, 0.05],
                                 [-3.0, -40.0, 0.0])[:, [3, 2, 1]]
        coors_w_r = torch.cat(
            (coors, torch.zeros(coors.shape[0], 1, dtype=coors.dtype,
                                device=coors.device)),
            dim=1)

        # deal with batches
        for batch in range(len(example['scene_idx'])):
            scene_idx = example['scene_idx'][batch]
            b_idx = filtered_indices[:, 0] == batch
            pred_dict[scene_idx].append(coors_w_r[b_idx].cpu().numpy())
    print('done.', flush=True)

    # merge with original velodyne data
    print('writing to file ...', flush=True)
    os.makedirs(output_path, exist_ok=True)
    for s_idx in tqdm(pred_dict.keys()):
        file_name = str(s_idx).zfill(6) + '.bin'
        scene_pts = read_bin(os.path.join(velodyne_path, file_name))
        # cast to the on-disk dtype: tofile() writes raw bytes, so a promoted
        # dtype (e.g. float64) would silently change the file layout
        merged_pts = np.concatenate([*pred_dict[s_idx], scene_pts],
                                    axis=0).astype(scene_pts.dtype, copy=False)
        with open(os.path.join(output_path, file_name), 'wb') as f:
            merged_pts.tofile(f)
    print('done.', flush=True)
예제 #9
0
    def loss(self, example, ff_ret_dict):
        """Focal loss between predicted voxels and the ground-truth voxel set.

        Predicted indices (``self.out_key``) and ground-truth indices
        (``self.gt_key``) are merged into one sorted "dummy" set. Dummy
        voxels carry the sigmoid of the prediction, or 1.0 when only the
        ground truth has them. Labels are 1.0 at ground-truth voxels and 0
        elsewhere. ``FocalLoss`` is applied to the full dummy set.

        Returns:
            ``(loss, loss_info)``. ``loss_info`` holds three counts:

            - ``gt_num``: number of ground-truth indices;
            - ``pred_num``: number of predicted indices;
            - ``drop_gt_num``: dummy voxels added on top of the predictions.

            It also holds ``recall``: the fraction of ground-truth entries
            scoring above 0.3 / 0.5 / 0.7 / 0.9. Recall is reported as 0.0
            when there is no ground truth; dividing by zero would give NaN.
        """
        # final output sp tensor of the network
        pred_sp = get_input(ff_ret_dict, self.out_key)

        # predicted indices and features (sigmoid function applied)
        pred_indices = pred_sp.indices
        pred_features = torch.sigmoid(pred_sp.features)

        # true ground truth indices
        label_indices = get_input(ff_ret_dict, self.gt_key)

        pred_num = pred_indices.shape[0]
        gt_num = label_indices.shape[0]

        # dummy set = sorted union of predicted and ground-truth indices;
        # voxels present only in the ground truth keep a feature of 1.0
        indices_cated = torch.cat((pred_indices, label_indices), dim=0)
        dummy_pred_indices, i_indices_index = torch.unique(indices_cated,
                                                           dim=0,
                                                           return_inverse=True)
        dummy_pred_features = torch.ones(dummy_pred_indices.shape[0],
                                         1,
                                         dtype=pred_features.dtype,
                                         device=pred_features.device)
        dummy_pred_features[i_indices_index[:pred_num]] = pred_features

        # Positions of the ground-truth rows inside dummy_pred_indices. The
        # union already contains every ground-truth row, so this tail of the
        # inverse index equals what a second torch.unique over
        # (dummy_pred_indices, label_indices) would return.
        gt_pos = i_indices_index[pred_num:]

        # label features (same shape as dummy_pred_features)
        label_features = torch.zeros(dummy_pred_features.shape[0],
                                     1,
                                     device=dummy_pred_features.device)
        label_features[gt_pos] = 1.0

        # calculate focal_loss
        focal_loss = FocalLoss()
        loss = focal_loss(dummy_pred_features, label_features)

        # get stat.
        drop_gt_num = dummy_pred_indices.shape[0] - pred_num

        # recall at different confidence thresholds
        gt_scores = dummy_pred_features[gt_pos]
        recall = {}
        for thre in ('0.3', '0.5', '0.7', '0.9'):
            recall[thre] = (float((gt_scores > float(thre)).sum().float() / gt_num)
                            if gt_num > 0 else 0.0)

        # information to be logged and displayed
        loss_info = {
            'gt_num': gt_num,
            'drop_gt_num': drop_gt_num,
            'pred_num': pred_num,
            'recall': recall,
        }

        return loss, loss_info