def random(cls, rng=None, **kwargs):
    """Build a randomly generated sampling result for demos and tests.

    A random ``AssignResult`` is created first, random predicted and
    ground-truth boxes are generated to match its sizes, and a
    ``RandomSampler`` is then run over them.

    Args:
        rng (None | int | numpy.random.RandomState): seed or state

    Kwargs:
        Forwarded unchanged to ``AssignResult.random``:

        num_preds: number of predicted boxes
        num_gts: number of true boxes
        p_ignore (float): probability of a predicted box assigned to an
            ignored truth
        p_assigned (float): probability of a predicted box not being
            assigned
        p_use_label (float | bool): with labels or not

    Returns:
        The sampling result produced by ``RandomSampler.sample`` for the
        randomly generated assignment and boxes.

    Example:
        >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA
        >>> self = SamplingResult.random()
        >>> print(self.__dict__)
    """
    from mmdet.core.bbox.samplers.random_sampler import RandomSampler
    from mmdet.core.bbox.assigners.assign_result import AssignResult
    from mmdet.core.bbox import demodata
    rng = demodata.ensure_rng(rng)

    # Fixed sampler configuration; make probabilistic?
    num = 32
    pos_fraction = 0.5
    neg_pos_ub = -1

    assign_result = AssignResult.random(rng=rng, **kwargs)

    # Note we could just compute an assignment
    bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
    gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)

    if rng.rand() > 0.2:
        # sometimes algorithms squeeze their data, be robust to that
        gt_bboxes = gt_bboxes.squeeze()
        bboxes = bboxes.squeeze()

    # TODO: generate random ground-truth labels when
    # ``assign_result.labels`` is not None. Until then no labels are
    # produced, regardless of whether the assignment carries labels.
    gt_labels = None

    # Ground-truth boxes are added as proposals only when labels are
    # available (currently never). Make probabilistic?
    add_gt_as_proposals = gt_labels is not None

    sampler = RandomSampler(
        num,
        pos_fraction,
        # The keyword must be spelled ``neg_pos_ub``; a misspelled name is
        # not bound to the sampler's upper-bound parameter.
        neg_pos_ub=neg_pos_ub,
        add_gt_as_proposals=add_gt_as_proposals,
        rng=rng)
    self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
    return self
def test_random_assign_result():
    """Smoke-test ``AssignResult.random`` across corner-case box counts.

    The test makes no assertions of its own; it passes as long as none of
    the random constructions raises.
    """
    from mmdet.core.bbox.assigners.assign_result import AssignResult

    # Default arguments first.
    AssignResult.random()

    # (num_gts, num_preds) pairs, exercised in this exact order. The
    # repeated (0, 3) entry is intentional: it keeps the sequence of calls
    # identical to the original hand-written list.
    box_counts = (
        (0, 0),
        (0, 3),
        (3, 3),
        (0, 3),
        (7, 7),
        (7, 64),
        (24, 3),
    )
    for gt_count, pred_count in box_counts:
        AssignResult.random(num_gts=gt_count, num_preds=pred_count)