Exemplo n.º 1
0
    def __init__(
        self,
        db_infos,  # object/dbinfos_train.pkl
        groups,  # [dict(Car=15,),],
        db_prepor=None,  # filter_by_min_num_points, filter_by_difficulty
        rate=1.0,  # rate=1.0
        global_rot_range=None,  # [0, 0]
        logger=None,  # logging.getLogger("build_dbsampler")
        gt_random_drop=-1.0,
        gt_aug_with_context=-1.0,
        gt_aug_similar_type=False,
    ):
        """Build one batch sampler per class from the ground-truth database.

        Args:
            db_infos: Mapping from class name to a sized collection of
                database entries. It is logged, optionally filtered, and
                stored as ``self.db_infos``.
            groups: Iterable of dicts mapping class name to the maximum
                number of samples for that class, e.g. ``[dict(Car=15)]``.
            db_prepor: Optional callable taking ``db_infos`` and returning
                the filtered mapping that replaces it.
            rate: Stored as ``self._rate``.
            global_rot_range: Accepted for signature compatibility; this
                constructor does not read it.
            logger: Logger used for the database summaries. When ``None``
                a module-level logger is used instead.
            gt_random_drop: Stored as ``self.gt_point_random_drop``.
            gt_aug_with_context: Stored as ``self.gt_aug_with_context``.
            gt_aug_similar_type: When truthy, the ``"Car"`` sampler is
                rebuilt over the ``"Car"`` entries concatenated with the
                ``"Van"`` entries (both keys must then exist).
        """
        # ``logger`` defaults to None in the signature; without a fallback
        # the first ``logger.info`` call below would raise AttributeError.
        if logger is None:
            import logging

            logger = logging.getLogger(__name__)

        # Report the size of every class in the raw database.
        for k, v in db_infos.items():
            logger.info(f"load {len(v)} {k} database infos")

        # Optional preprocessing (e.g. filter by min points / difficulty),
        # followed by the same report for the filtered database.
        if db_prepor is not None:
            db_infos = db_prepor(db_infos)
            logger.info("After filter database:")
            for k, v in db_infos.items():
                logger.info(f"load {len(v)} {k} database infos")

        self.db_infos = db_infos
        self._rate = rate
        self._groups = groups
        self._group_name_to_names = []
        self._sample_classes = []
        self._sample_max_nums = []
        self.gt_point_random_drop = gt_random_drop
        self.gt_aug_with_context = gt_aug_with_context

        # No group sampling in this variant: the samplers are keyed by the
        # same class names as the database itself.
        self._group_db_infos = self.db_infos
        for group_info in groups:
            self._sample_classes.extend(group_info.keys())  # ['Car']
            self._sample_max_nums.extend(group_info.values())  # [15]

        # One sampler per class (Car, Cyclist, Pedestrian, ...).
        self._sampler_dict = {
            k: prep.BatchSampler(v, k)
            for k, v in self._group_db_infos.items()
        }

        if gt_aug_similar_type:
            # Let "Car" sampling also draw from the "Van" entries.
            self._sampler_dict["Car"] = prep.BatchSampler(
                self._group_db_infos["Car"] + self._group_db_infos["Van"],
                "Car")
Exemplo n.º 2
0
    def __init__(
        self,
        db_infos,
        groups,
        db_prepor=None,
        rate=1.0,
        global_rot_range=None,
        logger=None,
    ):
        """Build batch samplers over the ground-truth database.

        Args:
            db_infos: Mapping from class name to a sized collection of
                database entries. In group-sampling mode every entry must
                provide the ``"group_id"`` and ``"name"`` keys.
            groups: Sequence of dicts mapping class name to a maximum
                sample count. If any dict has more than one entry, group
                sampling is enabled for all of them.
            db_prepor: Optional callable taking ``db_infos`` and returning
                the filtered mapping that replaces it.
            rate: Stored as ``self._rate``.
            global_rot_range: ``None``, a scalar ``r`` (expanded to
                ``[-r, r]``), or a list/tuple/ndarray that must satisfy
                ``shape_mergeable(global_rot_range, [2])``.
            logger: Logger used for the database summaries. When ``None``
                a module-level logger is used instead.

        Raises:
            ValueError: If two entries of ``groups`` produce the same
                joined group name in group-sampling mode.
        """
        # ``logger`` defaults to None in the signature; without a fallback
        # the first ``logger.info`` call below would raise AttributeError.
        if logger is None:
            import logging

            logger = logging.getLogger(__name__)

        for k, v in db_infos.items():
            logger.info(f"load {len(v)} {k} database infos")

        if db_prepor is not None:
            db_infos = db_prepor(db_infos)
            logger.info("After filter database:")
            for k, v in db_infos.items():
                logger.info(f"load {len(v)} {k} database infos")

        self.db_infos = db_infos
        self._rate = rate
        self._groups = groups
        self._group_db_infos = {}
        self._group_name_to_names = []
        self._sample_classes = []
        self._sample_max_nums = []
        # Group sampling (slower) is needed as soon as one group spans
        # more than one class.
        self._use_group_sampling = any(len(g) > 1 for g in groups)

        if not self._use_group_sampling:
            # Every group is a single class: reuse the database as is.
            self._group_db_infos = self.db_infos
            for group_info in groups:
                self._sample_classes.extend(group_info.keys())
                self._sample_max_nums.extend(group_info.values())
        else:
            for group_info in groups:
                group_names = list(group_info.keys())
                group_name = ", ".join(group_names)
                self._sample_classes += group_names
                self._sample_max_nums += list(group_info.values())
                self._group_name_to_names.append((group_name, group_names))

                # Bucket the entries of all member classes by group id.
                group_dict = {}
                for name in group_names:
                    for item in db_infos[name]:
                        group_dict.setdefault(item["group_id"], []).append(item)

                if group_name in self._group_db_infos:
                    raise ValueError("group must be unique")
                group_data = list(group_dict.values())
                self._group_db_infos[group_name] = group_data

                # Count how often each combination of class names occurs
                # within one bucket; the summary goes to stdout.
                info_dict = {}
                if len(group_info) > 1:
                    for group in group_data:
                        combo = ", ".join(sorted(item["name"] for item in group))
                        info_dict[combo] = info_dict.get(combo, 0) + 1
                print(info_dict)

        self._sampler_dict = {
            k: prep.BatchSampler(v, k)
            for k, v in self._group_db_infos.items()
        }

        self._enable_global_rot = False
        if global_rot_range is not None:
            if not isinstance(global_rot_range, (list, tuple, np.ndarray)):
                # A scalar is taken as a symmetric range around zero.
                global_rot_range = [-global_rot_range, global_rot_range]
            else:
                assert shape_mergeable(global_rot_range, [2])
            # Rotation is only enabled for a range wider than 1e-3.
            if np.abs(global_rot_range[0] - global_rot_range[1]) >= 1e-3:
                self._enable_global_rot = True
        self._global_rot_range = global_rot_range
Exemplo n.º 3
0
    def __init__(
            self,
            db_infos,  # object/dbinfos_train.pkl
            groups,  # [dict(Car=15,),],
            db_prepor=None,  # filter_by_min_num_points, filter_by_difficulty
            rate=1.0,  # rate=1.0
            global_rot_range=None,  # [0, 0]
            logger=None,  # logging.getLogger("build_dbsampler")
    ):
        """Build batch samplers over the ground-truth database.

        Args:
            db_infos: Mapping from class name to a sized collection of
                database entries. In group-sampling mode every entry must
                provide the ``"group_id"`` and ``"name"`` keys.
            groups: Sequence of dicts mapping class name to a maximum
                sample count, e.g. ``[dict(Car=15)]``. If any dict has
                more than one entry, group sampling is enabled.
            db_prepor: Optional callable taking ``db_infos`` and returning
                the filtered mapping that replaces it.
            rate: Stored as ``self._rate``.
            global_rot_range: ``None``, a scalar ``r`` (expanded to
                ``[-r, r]``), or a list/tuple/ndarray that must satisfy
                ``shape_mergeable(global_rot_range, [2])``.
            logger: Logger used for the database summaries. When ``None``
                a module-level logger is used instead.

        Raises:
            ValueError: If two entries of ``groups`` produce the same
                joined group name in group-sampling mode.
        """
        # ``logger`` defaults to None in the signature; without a fallback
        # the first ``logger.info`` call below would raise AttributeError.
        if logger is None:
            import logging

            logger = logging.getLogger(__name__)

        # Report the size of every class in the raw database.
        for k, v in db_infos.items():
            logger.info(f"load {len(v)} {k} database infos")

        # Optional preprocessing (e.g. filter by min points / difficulty),
        # followed by the same report for the filtered database.
        if db_prepor is not None:
            db_infos = db_prepor(db_infos)
            logger.info("After filter database:")
            for k, v in db_infos.items():
                logger.info(f"load {len(v)} {k} database infos")

        self.db_infos = db_infos
        self._rate = rate
        self._groups = groups
        self._group_db_infos = {}
        self._group_name_to_names = []
        self._sample_classes = []
        self._sample_max_nums = []
        # Group sampling (slower) is switched on when at least one group
        # lists more than one class.
        self._use_group_sampling = any(len(g) > 1 for g in groups)

        if self._use_group_sampling:
            for group_info in groups:
                group_names = list(group_info.keys())
                group_name = ", ".join(group_names)
                self._sample_classes += group_names
                self._sample_max_nums += list(group_info.values())
                self._group_name_to_names.append((group_name, group_names))

                # Collect the entries of every member class per group id.
                group_dict = {}
                for name in group_names:
                    for item in db_infos[name]:
                        group_dict.setdefault(item["group_id"], []).append(item)

                if group_name in self._group_db_infos:
                    raise ValueError("group must be unique")
                group_data = list(group_dict.values())
                self._group_db_infos[group_name] = group_data

                # Tally the class-name combinations found in the buckets
                # of a multi-class group; the tally is printed to stdout.
                info_dict = {}
                if len(group_info) > 1:
                    for group in group_data:
                        combo = ", ".join(sorted(item["name"] for item in group))
                        info_dict[combo] = info_dict.get(combo, 0) + 1
                print(info_dict)
        else:
            # Single-class groups only: samplers are keyed by class name,
            # so the database mapping is used directly.
            self._group_db_infos = self.db_infos
            for group_info in groups:
                self._sample_classes.extend(group_info.keys())  # ['Car']
                self._sample_max_nums.extend(group_info.values())  # [15]

        # One sampler per key (class name, or joined group name).
        self._sampler_dict = {
            k: prep.BatchSampler(v, k)
            for k, v in self._group_db_infos.items()
        }

        # Global rotation range.
        self._enable_global_rot = False
        if global_rot_range is not None:
            if not isinstance(global_rot_range, (list, tuple, np.ndarray)):
                # A scalar is taken as a symmetric range around zero.
                global_rot_range = [-global_rot_range, global_rot_range]
            else:
                assert shape_mergeable(global_rot_range, [2])
            # Rotation is only enabled for a range wider than 1e-3, so a
            # value such as [0, 0] leaves it disabled.
            if np.abs(global_rot_range[0] - global_rot_range[1]) >= 1e-3:
                self._enable_global_rot = True
        self._global_rot_range = global_rot_range