Ejemplo n.º 1
0
    def test_loader(self):
        """Assemble a reader over a ``VOCDataSet`` and return it.

        Nothing is asserted here: the method builds the dataset, the
        per-sample and per-batch transform lists and the ``Reader``, then
        calls the ``Reader`` instance and hands back whatever that call
        produces.

        Returns:
            The object returned by ``Reader(...)()``.
        """
        # NOTE(review): dataset_dir is fed from self.image_dir and image_dir
        # from self.root_path, the opposite pairing to
        # test_loader_multi_threads -- confirm this is intentional.
        voc_dataset = VOCDataSet(
            dataset_dir=self.image_dir,
            image_dir=self.root_path,
            anno_path=self.anno_path,
            sample_num=240,
            use_default_label=False,
            label_list='/path/to/your/fl_fruit/label_list.txt',
        )

        # Per-sample operators, created in the order they are listed.
        decode_op = DecodeImage(to_rgb=True)
        flip_op = RandomFlipImage()
        normalize_op = NormalizeImage(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            is_scale=True,
            is_channel_first=False,
        )
        resize_op = ResizeImage(target_size=800, max_size=1333, interp=1)
        permute_op = Permute(to_bgr=False)
        per_sample_ops = [
            decode_op, flip_op, normalize_op, resize_op, permute_op
        ]

        # Per-batch operators.
        per_batch_ops = [PadBatch(pad_to_stride=32, use_padded_im_info=True)]

        # Fields requested from the reader for every sample.
        field_names = [
            'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd'
        ]

        reader = Reader(
            voc_dataset,
            sample_transforms=per_sample_ops,
            batch_transforms=per_batch_ops,
            batch_size=1,
            shuffle=True,
            drop_empty=True,
            inputs_def={'fields': field_names},
        )
        return reader()
    def test_loader_multi_threads(self):
        """Check the shapes of samples read from a ``COCODataSet``.

        A ``Reader`` is configured with ``worker_num=2`` and
        ``use_process=False`` over 10 samples, iterated twice (with a
        ``reset()`` after each pass), and every sample in every batch is
        checked field by field in the order given by ``inputs_def``.
        """
        coco_dataset = COCODataSet(
            dataset_dir=self.root_path,
            image_dir=self.image_dir,
            anno_path=self.anno_path,
            sample_num=10,
        )

        # Per-sample operators, created in the order they are listed.
        decode_op = DecodeImage(to_rgb=True)
        resize_op = ResizeImage(target_size=800, max_size=1333, interp=1)
        permute_op = Permute(to_bgr=False)
        per_sample_ops = [decode_op, resize_op, permute_op]

        # Per-batch operators.
        per_batch_ops = [PadBatch(pad_to_stride=32, use_padded_im_info=True)]

        # Field order here determines the positional indices used below.
        field_names = [
            'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd',
            'gt_mask'
        ]

        reader = Reader(
            coco_dataset,
            sample_transforms=per_sample_ops,
            batch_transforms=per_batch_ops,
            batch_size=2,
            shuffle=True,
            drop_empty=True,
            worker_num=2,
            use_process=False,
            bufsize=8,
            inputs_def={'fields': field_names},
        )
        data_loader = reader()

        # Two full passes, resetting the loader after each one.
        for _ in range(2):
            for batch in data_loader:
                for item in batch:
                    # image: first axis is 3, next two are multiples of the
                    # pad stride (32).
                    image_shape = item[0].shape
                    self.assertEqual(image_shape[0], 3)
                    self.assertEqual(image_shape[1] % 32, 0)
                    self.assertEqual(image_shape[2] % 32, 0)

                    # im_info: last axis has 3 entries.
                    info_shape = item[1].shape
                    self.assertEqual(info_shape[-1], 3)

                    # im_id: last axis has 1 entry.
                    id_shape = item[2].shape
                    self.assertEqual(id_shape[-1], 1)

                    # gt_bbox: last axis has 4 entries; its first axis is the
                    # count the remaining fields are compared against.
                    bbox_shape = item[3].shape
                    self.assertEqual(bbox_shape[-1], 4)
                    box_count = bbox_shape[0]

                    # gt_class: one entry per box.
                    class_shape = item[4].shape
                    self.assertEqual(class_shape[-1], 1)
                    self.assertEqual(class_shape[0], box_count)

                    # is_crowd: one entry per box.
                    crowd_shape = item[5].shape
                    self.assertEqual(crowd_shape[-1], 1)
                    self.assertEqual(crowd_shape[0], box_count)

                    # gt_mask: one entry per box; the first element of the
                    # first entry has a last axis of size 2.
                    masks = item[6]
                    self.assertEqual(len(masks), box_count)
                    self.assertEqual(masks[0][0].shape[-1], 2)
            data_loader.reset()