def test_dbnet_draw_border_map():
    """Smoke-test ``DBNetTargets.draw_border_map`` on an off-canvas polygon.

    Three of the four vertices have negative x coordinates, so the polygon
    extends past the left edge of the 40x40 canvas. Nothing about the drawn
    output is asserted: the test passes as long as the call does not raise.
    """
    generator = textdet_targets.DBNetTargets()
    height, width = 40, 40
    polygon = np.array([
        [20, 21],
        [-14, 20],
        [-11, 30],
        [-22, 26],
    ])
    # draw_border_map receives these buffers as its canvas/mask arguments.
    border_canvas = np.zeros((height, width), dtype=np.float32)
    border_mask = np.zeros((height, width), dtype=np.uint8)
    generator.draw_border_map(polygon, border_canvas, border_mask)
def test_dbnet_targets_find_invalid():
    """``find_invalid`` flags neither of two disjoint side-10 square polygons."""
    generator = textdet_targets.DBNetTargets()
    # Default hyper-parameters of a freshly constructed generator.
    assert generator.shrink_ratio == 0.4
    assert generator.thr_min == 0.3
    assert generator.thr_max == 0.7

    squares = [
        [np.array([0, 0, 10, 0, 10, 10, 0, 10])],
        [np.array([20, 0, 30, 0, 30, 10, 20, 10])],
    ]
    results = {'gt_masks': PolygonMasks(squares, 40, 40)}
    invalid_flags = generator.find_invalid(results)
    assert np.allclose(invalid_flags, [False, False])
def test_dbnet_generate_targets():
    """``generate_targets`` registers the four DBNet maps in ``mask_fields``.

    After the call, ``results['mask_fields']`` must name the shrink map, the
    shrink mask, the threshold map and the threshold mask.
    """
    generator = textdet_targets.DBNetTargets()
    kept_polys = [
        [np.array([0, 0, 10, 0, 10, 10, 0, 10])],
        [np.array([20, 0, 30, 0, 30, 10, 20, 10])],
    ]
    ignored_polys = [[np.array([0, 0, 15, 0, 15, 10, 0, 10])]]
    results = {
        'mask_fields': [],
        'img_shape': (40, 40, 3),
        'gt_masks_ignore': PolygonMasks(ignored_polys, 40, 40),
        'gt_masks': PolygonMasks(kept_polys, 40, 40),
        'gt_bboxes': np.array([[0, 0, 10, 10], [20, 0, 30, 10]]),
        'gt_labels': np.array([0, 1]),
    }
    generator.generate_targets(results)

    for field in ('gt_shrink', 'gt_shrink_mask', 'gt_thr', 'gt_thr_mask'):
        assert field in results['mask_fields']
def test_dbnet_ignore_texts():
    """``ignore_texts`` moves tagged polygons from gt to the ignore set.

    With tags ``[True, False]``, the first polygon must be appended to
    ``gt_masks_ignore`` (growing it from one to two entries), while
    ``gt_masks`` and ``gt_labels`` keep only the second polygon and label.
    """
    generator = textdet_targets.DBNetTargets()
    tags = [True, False]
    results = {}
    kept_polys = [
        [np.array([0, 0, 10, 0, 10, 10, 0, 10])],
        [np.array([20, 0, 30, 0, 30, 10, 20, 10])],
    ]
    ignored_polys = [[np.array([0, 0, 15, 0, 15, 10, 0, 10])]]
    results['gt_masks_ignore'] = PolygonMasks(ignored_polys, 40, 40)
    results['gt_masks'] = PolygonMasks(kept_polys, 40, 40)
    results['gt_bboxes'] = np.array([[0, 0, 10, 10], [20, 0, 30, 10]])
    results['gt_labels'] = np.array([0, 1])

    generator.ignore_texts(results, tags)

    assert np.allclose(results['gt_labels'], np.array([1]))
    ignore_masks = results['gt_masks_ignore'].masks
    assert len(ignore_masks) == 2
    # The newly ignored entry is the first (tagged) input polygon.
    assert np.allclose(ignore_masks[1][0], kept_polys[0][0])
    assert len(results['gt_masks'].masks) == 1
def test_dbnet_targets():
    """A default ``DBNetTargets`` exposes the expected hyper-parameters."""
    generator = textdet_targets.DBNetTargets()
    expected_defaults = {
        'shrink_ratio': 0.4,
        'thr_min': 0.3,
        'thr_max': 0.7,
    }
    for attr_name, expected in expected_defaults.items():
        assert getattr(generator, attr_name) == expected
def test_dbnet_generate_thr_map():
    """Every pixel of the generated threshold map lies in roughly [0.3, 0.7]."""
    generator = textdet_targets.DBNetTargets()
    squares = [
        [np.array([0, 0, 10, 0, 10, 10, 0, 10])],
        [np.array([20, 0, 30, 0, 30, 10, 20, 10])],
    ]
    thr_map, _thr_mask = generator.generate_thr_map((40, 40), squares)
    # The bounds sit 0.01 outside the default thr_min/thr_max (0.3/0.7),
    # presumably to tolerate floating-point rounding.
    in_range = (thr_map >= 0.29) & (thr_map <= 0.71)
    assert np.all(in_range)