예제 #1
0
def vis_scene_data(nuscenes_path=None,
                   nuscenes_version='v1.0-trainval',
                   which_scene=0,
                   max_seq_num=10,
                   begin_frame=0,
                   frame_skip=3,
                   trained_model_path=None,
                   img_save_dir=None,
                   which_model='MotionNet',
                   use_adj_frame_pred=True,
                   use_motion_state_pred_masking=True,
                   disp=True):
    """
    Visualize the predictions of a pretrained model on one nuScenes scene.

    The LIDAR_TOP sample_data records of the scene are walked in order. Every
    frame that has exactly nsweeps_back + nsweeps_forward sweeps around it is
    counted as one sequence sample. The first `begin_frame` of these are
    skipped, and the following (at most) `max_seq_num` are voxelized, turned
    into ground-truth BEV maps and passed to vis_model_per_sample_data.
    Walking stops after the LIDAR_TOP frame of the scene's last keyframe.

    nuscenes_path: the path to the nuScenes dataset
    nuscenes_version: the dataset version ['v1.0-trainval'/'v1.0-mini']
    which_scene: index into nusc.scene of the scene to visualize
    max_seq_num: how many frames to visualize
    begin_frame: number of valid frames of this scene to skip before the first visualized one
    frame_skip: number of sweeps skipped between two used frames (sweeps are
        sampled with stride frame_skip + 1), as used for the preprocessed BEV data
    trained_model_path: the path to the pretrained model checkpoint
    img_save_dir: the directory for saving the predicted images
    which_model: which network ['MotionNet'/'MotionNetMGDA']
    use_adj_frame_pred: whether to predict the relative offsets between two adjacent frames
    use_motion_state_pred_masking: whether to threshold the prediction with motion state estimation results
    disp: whether to immediately show the predicted results

    Raises ValueError if nuscenes_path or trained_model_path is not given.
    """
    if nuscenes_path is None:
        raise ValueError("Should specify the nuScenes data path.")
    # Fail fast, before the (slow) dataset loading; torch.load(None) would
    # otherwise fail later with an obscure error.
    if trained_model_path is None:
        raise ValueError("Should specify the path to the pretrained model.")

    nusc = NuScenes(version=nuscenes_version,
                    dataroot=nuscenes_path,
                    verbose=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    nsweeps_back = 20
    nsweeps_forward = 20
    num_frame_skipped = 0  # > 0 would drop every (num_frame_skipped + 1)-th frame

    voxel_size = (0.25, 0.25, 0.4)
    area_extents = np.array([[-32., 32.], [-32., 32.], [-3, 2]])

    sample_cnt = 1
    class_map = {
        'vehicle.car': 1,
        'vehicle.bus.rigid': 1,
        'vehicle.bus.bendy': 1,
        'human.pedestrian': 2,
        'vehicle.bicycle': 3
    }  # background: 0, other: 4

    curr_scene = nusc.scene[which_scene]
    curr_sample = nusc.get('sample', curr_scene['first_sample_token'])
    curr_sample_data = nusc.get('sample_data',
                                curr_sample['data']['LIDAR_TOP'])

    # 'last_sample_token' identifies a *sample* record, while the loop below
    # walks *sample_data* records, so those two tokens are never equal.
    # Compare against the LIDAR_TOP sample_data of the last keyframe instead.
    last_sample = nusc.get('sample', curr_scene['last_sample_token'])
    last_keyframe_sd_token = last_sample['data']['LIDAR_TOP']

    def next_sample_data(sample_data):
        # Follow the 'next' link of a sample_data record.
        return nusc.get('sample_data', sample_data['next'])

    # Load pre-trained network weights. map_location lets a checkpoint that
    # was saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(trained_model_path, map_location=device)
    if which_model == "MotionNet":
        model = MotionNet(out_seq_len=20,
                          motion_category_num=2,
                          height_feat_size=13)
        model = nn.DataParallel(model)
        model.load_state_dict(checkpoint['model_state_dict'])
        loaded_models = [model.to(device)]
    else:
        model_encoder = nn.DataParallel(FeatEncoder())
        model_head = nn.DataParallel(
            MotionNetMGDA(out_seq_len=20, motion_category_num=2))
        model_encoder.load_state_dict(checkpoint['encoder_state_dict'])
        model_head.load_state_dict(checkpoint['head_state_dict'])
        loaded_models = [model_encoder.to(device), model_head.to(device)]
    print("Loaded pretrained model {}".format(which_model))

    has_reached_last_keyframe = False
    seq_num = 0

    while curr_sample_data['next'] != '':
        if has_reached_last_keyframe:
            break

        # The last keyframe itself is still processed; the loop stops after it.
        if curr_sample_data['token'] == last_keyframe_sd_token:
            has_reached_last_keyframe = True

        # Skip current frame if requested (never taken while num_frame_skipped == 0)
        if num_frame_skipped > 0 and sample_cnt % (num_frame_skipped + 1) == 0:
            sample_cnt += 1
            curr_sample_data = next_sample_data(curr_sample_data)
            continue

        # Get the synchronized point clouds
        all_pc, all_times = LidarPointCloud.from_file_multisweep_bf_sample_data(
            nusc,
            curr_sample_data,
            nsweeps_back=nsweeps_back,
            nsweeps_forward=nsweeps_forward)

        pc = all_pc.points
        _, sort_idx = np.unique(all_times, return_index=True)
        # Preserve the item order in unique_times
        unique_times = all_times[np.sort(sort_idx)]
        num_sweeps = len(unique_times)

        # Make sure we have sufficient past and future sweeps
        if num_sweeps != (nsweeps_back + nsweeps_forward):
            sample_cnt += 1
            curr_sample_data = next_sample_data(curr_sample_data)
            continue

        # Valid frames before begin_frame are only counted. The frame must be
        # advanced here too: otherwise the same frame would be re-loaded
        # begin_frame times and the scene's first valid frame visualized anyway.
        # Checking before the annotation loop also avoids loading boxes that
        # would be discarded.
        if seq_num < begin_frame:
            seq_num += 1
            sample_cnt += 1
            print("Finish loading sequence sample {}".format(seq_num))
            curr_sample_data = next_sample_data(curr_sample_data)
            continue

        # Prepare data dictionary for visualization: one point cloud per sweep
        save_data_dict = dict()
        for tid, sweep_time in enumerate(unique_times):
            points_idx = np.where(all_times == sweep_time)[0]
            save_data_dict['pc_' + str(tid)] = pc[:, points_idx]
        save_data_dict['times'] = unique_times
        save_data_dict['num_sweeps'] = num_sweeps

        # Get the synchronized boxes of every annotation of the sample that this
        # sample_data's 'sample_token' refers to.
        corresponding_sample_rec = nusc.get('sample',
                                            curr_sample_data['sample_token'])
        num_instances = 0  # The number of instances within this sample
        for ann_token in corresponding_sample_rec['anns']:
            ann_rec = nusc.get('sample_annotation', ann_token)
            category_name = ann_rec['category_name']

            # First matching class prefix wins; anything else is "other" (4)
            save_data_dict['category_' + str(num_instances)] = next(
                (v for c, v in class_map.items()
                 if category_name.startswith(c)), 4)

            instance_boxes, instance_all_times, _, _ = LidarPointCloud. \
                get_instance_boxes_multisweep_sample_data(nusc, curr_sample_data,
                                                          ann_rec['instance_token'],
                                                          nsweeps_back=nsweeps_back,
                                                          nsweeps_forward=nsweeps_forward)

            assert np.array_equal(
                unique_times, instance_all_times
            ), "The sweep and instance times are not consistent!"
            assert num_sweeps == len(
                instance_boxes
            ), "The number of instance boxes does not match that of sweeps!"

            # One row per sweep: box center, box size, quaternion; NaN where
            # the instance has no box in that sweep.
            box_data = np.full((len(instance_boxes), 3 + 3 + 4), np.nan,
                               dtype=np.float32)
            for r, box in enumerate(instance_boxes):
                if box is not None:
                    box_data[r] = np.concatenate(
                        [box.center, box.wlh, box.orientation.elements])

            save_data_dict['instance_boxes_' + str(num_instances)] = box_data
            num_instances += 1

        save_data_dict['num_instances'] = num_instances

        # ------------------------------------ Visualization ------------------------------------
        # -- The following code is simply borrowed from gen_data.py and currently not optimized
        times = save_data_dict['times']
        # Sweeps with time >= 0 are counted as past sweeps, those < 0 as future ones
        num_past_sweeps = len(np.where(times >= 0)[0])
        num_future_sweeps = len(np.where(times < 0)[0])
        assert num_past_sweeps + num_future_sweeps == num_sweeps, "The number of sweeps is incorrect!"

        pc_list = [save_data_dict['pc_' + str(i)].T for i in range(num_sweeps)]

        # Reorder the pc, and skip sample frames if wanted: past sweeps are
        # taken with stride frame_skip + 1 and reversed; future sweeps use the
        # same stride, starting frame_skip sweeps after the last past one.
        past_pc_list = pc_list[0:num_past_sweeps:(frame_skip + 1)][::-1]
        future_pc_list = pc_list[(num_past_sweeps + frame_skip)::(frame_skip + 1)]
        pc_list = past_pc_list + future_pc_list

        num_past_pcs = len(past_pc_list)
        num_future_pcs = len(future_pc_list)

        # Voxelize the past point clouds (the network input)
        padded_voxel_points_list = [
            voxelize_occupy(past_pc, voxel_size=voxel_size, extents=area_extents)
            for past_pc in past_pc_list
        ]
        padded_voxel_points = torch.from_numpy(
            np.stack(padded_voxel_points_list, axis=0))

        # Finally, generate the ground-truth displacement field
        all_disp_field_gt, all_valid_pixel_maps, non_empty_map, pixel_cat_map \
            = gen_2d_grid_gt_for_visualization(save_data_dict, grid_size=voxel_size[0:2], reordered=True,
                                               extents=area_extents, frame_skip=frame_skip)

        bev_input_data = (padded_voxel_points, all_disp_field_gt,
                          all_valid_pixel_maps, non_empty_map, pixel_cat_map,
                          num_past_pcs, num_future_pcs)

        vis_model_per_sample_data(
            bev_input_data,
            save_data_dict,
            frame_skip=frame_skip,
            loaded_models=loaded_models,
            voxel_size=voxel_size,
            which_model=which_model,
            model_path=trained_model_path,
            img_save_dir=img_save_dir,
            use_adj_frame_pred=use_adj_frame_pred,
            disp=disp,
            use_motion_state_pred_masking=use_motion_state_pred_masking,
            frame_idx=seq_num)

        seq_num += 1
        print("Finish loading sequence sample {}".format(seq_num))

        sample_cnt += 1
        curr_sample_data = next_sample_data(curr_sample_data)

        if seq_num - begin_frame >= max_seq_num:
            break
예제 #2
0
def _advance_record(nusc, table_name, record, num_steps):
    """Follow the 'next' link of a nuScenes `table_name` record `num_steps` times.

    Returns the reached record, or None if a record with an empty 'next' is
    hit before all steps are taken.
    """
    for _ in range(num_steps):
        if record['next'] == '':
            return None
        record = nusc.get(table_name, record['next'])
    return record


def _remap_instance_ids(save_data_dict_list, save_box_dict_list,
                        save_instance_token_list):
    """Store each frame's boxes/categories under integer instance ids, in place.

    Instances present in every frame of the sequence get the same ids
    0..K-1 in every frame, in the order they appear in the first frame. The
    remaining instances of each frame follow, in that frame's annotation
    order. Ordering by list order (not set iteration order) keeps the ids
    reproducible across runs, since str hashing is randomized per process.
    """
    common = set(save_instance_token_list[0]).intersection(
        *save_instance_token_list[1:])
    common_tokens = [t for t in save_instance_token_list[0] if t in common]

    for data_dict, box_dict, tokens in zip(save_data_dict_list,
                                           save_box_dict_list,
                                           save_instance_token_list):
        exclusive_tokens = [t for t in tokens if t not in common]
        ordered_tokens = common_tokens + exclusive_tokens
        for index, token in enumerate(ordered_tokens):
            data_dict['instance_boxes_' + str(index)] = box_dict['instance_boxes_' + token]
            data_dict['category_' + str(index)] = box_dict['category_' + token]

        assert len(ordered_tokens) == data_dict[
            'num_instances'], "The number of instances is inconsistent."


def gen_data():
    """
    Generate sparse BEV samples for the scenes listed in the module-level `scenes`.

    Relies on module-level settings: `nusc`, `scenes`, `nsweeps_back`,
    `nsweeps_forward`, `class_map`, `num_adj_seqs`, `num_keyframe_skipped`,
    `skip_frame`, `args.savepath`, and the helpers convert_to_dense_bev /
    convert_to_sparse_bev. Each entry of `scenes` is split on '_', and its
    second field is used as an index into nusc.scene.

    For each scene, starting at a keyframe, `num_adj_seqs` LIDAR_TOP frames
    spaced `skip_frame + 1` sample_data records apart are collected into one
    sequence. Each frame stores its per-sweep point clouds, sweep times,
    transformation matrices and the boxes/categories of every annotated
    instance. When a sequence is complete, instance ids are aligned across its
    frames, and every frame is converted and saved to
    `<args.savepath>/<scene_idx>_<seq_cnt>/<frame_idx>.npy`. The next sequence
    starts `num_keyframe_skipped + 1` keyframes after the previous start. A
    frame without the full window of past/future sweeps discards the partial
    sequence and jumps forward by keyframes in the same way.
    """
    scene_indices = [int(s.split('_')[1]) for s in scenes]

    for scene_idx in scene_indices:
        curr_scene = nusc.scene[scene_idx]
        curr_sample = nusc.get('sample', curr_scene['first_sample_token'])
        curr_sample_data = nusc.get('sample_data',
                                    curr_sample['data']['LIDAR_TOP'])

        # Per-frame data of the sequence being collected: timestamps, points,
        # etc.; box annotations keyed by instance token; the instance tokens
        # of each frame.
        save_data_dict_list = []
        save_box_dict_list = []
        save_instance_token_list = []
        adj_seq_cnt = 0
        save_seq_cnt = 0  # only used for save data file name

        # Iterate each sample data
        print("Processing scene {} ...".format(scene_idx))
        while curr_sample_data['next'] != '':

            # Get the synchronized point clouds
            all_pc, all_times, trans_matrices = \
                LidarPointCloud.from_file_multisweep_bf_sample_data(nusc, curr_sample_data,
                                                                    return_trans_matrix=True,
                                                                    nsweeps_back=nsweeps_back,
                                                                    nsweeps_forward=nsweeps_forward)
            pc = all_pc.points
            _, sort_idx = np.unique(all_times, return_index=True)
            # Preserve the item order in unique_times
            unique_times = all_times[np.sort(sort_idx)]
            num_sweeps = len(unique_times)

            # Make sure we have sufficient past and future sweeps; otherwise
            # drop the partial sequence and restart from a later keyframe.
            if num_sweeps != (nsweeps_back + nsweeps_forward):
                curr_sample = _advance_record(nusc, 'sample', curr_sample,
                                              num_keyframe_skipped + 1)
                if curr_sample is None:  # No more keyframes
                    break
                curr_sample_data = nusc.get('sample_data',
                                            curr_sample['data']['LIDAR_TOP'])

                # Reset
                adj_seq_cnt = 0
                save_data_dict_list = []
                save_box_dict_list = []
                save_instance_token_list = []
                continue

            # Prepare data dictionary for the next step (ie, generating BEV maps)
            save_data_dict = dict()
            box_data_dict = dict()  # keyed by instance token; remapped to ids later
            curr_token_list = []

            for tid, sweep_time in enumerate(unique_times):
                points_idx = np.where(all_times == sweep_time)[0]
                save_data_dict['pc_' + str(tid)] = pc[:, points_idx]

            save_data_dict['times'] = unique_times
            save_data_dict['num_sweeps'] = num_sweeps
            save_data_dict['trans_matrices'] = trans_matrices

            # Get the synchronized boxes of every annotation of the sample that
            # this sample_data's 'sample_token' refers to.
            corresponding_sample_rec = nusc.get('sample',
                                                curr_sample_data['sample_token'])
            num_instances = 0  # The number of instances within this sample

            for ann_token in corresponding_sample_rec['anns']:
                ann_rec = nusc.get('sample_annotation', ann_token)
                category_name = ann_rec['category_name']
                instance_token = ann_rec['instance_token']

                # First matching class prefix wins; anything else is "other" (4)
                box_data_dict['category_' + instance_token] = next(
                    (v for c, v in class_map.items()
                     if category_name.startswith(c)), 4)

                instance_boxes, instance_all_times, _, _ = LidarPointCloud. \
                    get_instance_boxes_multisweep_sample_data(nusc, curr_sample_data,
                                                              instance_token,
                                                              nsweeps_back=nsweeps_back,
                                                              nsweeps_forward=nsweeps_forward)

                assert np.array_equal(
                    unique_times, instance_all_times
                ), "The sweep and instance times are inconsistent!"
                assert num_sweeps == len(
                    instance_boxes
                ), "The number of instance boxes does not match that of sweeps!"

                # One row per sweep: box center, box size, quaternion; NaN
                # where the instance has no box in that sweep.
                box_data = np.full((len(instance_boxes), 3 + 3 + 4), np.nan,
                                   dtype=np.float32)
                for r, box in enumerate(instance_boxes):
                    if box is not None:
                        box_data[r] = np.concatenate(
                            [box.center, box.wlh, box.orientation.elements])

                box_data_dict['instance_boxes_' + instance_token] = box_data
                num_instances += 1
                curr_token_list.append(instance_token)

            save_data_dict['num_instances'] = num_instances
            save_data_dict_list.append(save_data_dict)
            save_box_dict_list.append(box_data_dict)
            save_instance_token_list.append(curr_token_list)

            # Update the counter and save the data if desired (But here we do not want to
            # save the data to disk since it would cost about 2TB space)
            adj_seq_cnt += 1
            if adj_seq_cnt == num_adj_seqs:
                # First, give instances consistent integer ids across the sequence
                _remap_instance_ids(save_data_dict_list, save_box_dict_list,
                                    save_instance_token_list)

                # ------------------------ Now we generate dense BEV maps ------------------------
                for seq_idx, seq_data_dict in enumerate(save_data_dict_list):
                    dense_bev_data = convert_to_dense_bev(seq_data_dict)
                    sparse_bev_data = convert_to_sparse_bev(dense_bev_data)

                    # save the data
                    save_directory = os.path.join(
                        args.savepath,
                        str(scene_idx) + '_' + str(save_seq_cnt))
                    os.makedirs(save_directory, exist_ok=True)
                    save_file_name = os.path.join(save_directory,
                                                  str(seq_idx) + '.npy')
                    np.save(save_file_name, arr=sparse_bev_data)

                    print("  >> Finish sample: {}, sequence {}".format(
                        save_seq_cnt, seq_idx))
                # --------------------------------------------------------------------------------

                save_seq_cnt += 1
                adj_seq_cnt = 0
                save_data_dict_list = []
                save_box_dict_list = []
                save_instance_token_list = []

                # Start the next sequence some keyframes later
                curr_sample = _advance_record(nusc, 'sample', curr_sample,
                                              num_keyframe_skipped + 1)
                if curr_sample is None:  # No more keyframes
                    break
                curr_sample_data = nusc.get('sample_data',
                                            curr_sample['data']['LIDAR_TOP'])
            else:
                # Move to the next frame of the current sequence
                curr_sample_data = _advance_record(nusc, 'sample_data',
                                                   curr_sample_data,
                                                   skip_frame + 1)
                if curr_sample_data is None:  # No more sample frames
                    break