def test_create(self):
    """ Test creating a reader from a user-supplied source.

    Wraps a generator over the training source in an `IteratorSource`,
    hands it to `Reader.create` through `my_source`, drains the reader and
    checks selected shape dimensions of the last item it produced.
    """

    def _my_data_reader():
        # Re-yield every sample of the configured training source.
        mydata = build_source(self.rcnn_conf['DATA']['TRAIN'])
        for sample in mydata:
            yield sample

    my_source = IteratorSource(_my_data_reader)
    mode = 'TRAIN'
    train_rd = Reader.create(
        mode,
        self.rcnn_conf['DATA'][mode],
        self.rcnn_conf['TRANSFORM'][mode],
        max_iter=10,
        my_source=my_source)

    out = None
    for sample in train_rd():
        out = sample
        self.assertIsNotNone(sample)

    # Fail with a readable message rather than a TypeError on `None[0]`
    # when the reader yields nothing at all.
    self.assertIsNotNone(out, 'reader created from my_source produced no data')

    # NOTE(review): out[0] is presumably the first sample of the last batch;
    # the meaning of each field index is not visible here -- confirm against
    # the transform configuration.
    self.assertEqual(out[0][0].shape[0], 3)
    self.assertEqual(out[0][1].shape[0], 3)
    self.assertEqual(out[0][3].shape[1], 4)
    self.assertEqual(out[0][4].shape[1], 1)
    self.assertEqual(out[0][5].shape[1], 1)
def create_reader(feed, max_iter=0, args_path=None, my_source=None):
    """
    Assemble the data and transform configuration described by `feed` and
    return an iterable data reader built by `Reader.create`.

    Args:
        feed: feed description object. Its `mode`, `num_workers`,
            `batch_size`, `drop_last`, `use_padded_im_info`,
            `batch_transforms` and `sample_transforms` attributes are read;
            `bufsize`, `use_process` and `memsize` are optional and default
            to 10, False and '3G'. `num_scale` is read only when a
            `PadMSTest` batch transform is present.
        max_iter (int): number of iterations.
        args_path: forwarded unchanged to `_prepare_data_config`.
        my_source (callable): callable function to create a source iterator
            which is used to provide source data in 'ppdet.data.reader'

    Returns:
        The reader produced by `Reader.create`.
    """
    # if `DATASET_DIR` does not exist, search ~/.paddle/dataset for a directory
    # named `DATASET_DIR` (e.g., coco, pascal), if not present either, download
    data_config = _prepare_data_config(feed, args_path)

    # Optional worker settings fall back to defaults when `feed` lacks them.
    buffer_size = getattr(feed, 'bufsize', 10)
    process_mode = getattr(feed, 'use_process', False)
    memory_size = getattr(feed, 'memsize', '3G')

    worker_conf = {
        'bufsize': buffer_size,
        'worker_num': feed.num_workers,
        'use_process': process_mode,
        'memsize': memory_size
    }
    transform_config = {
        'WORKER_CONF': worker_conf,
        'BATCH_SIZE': feed.batch_size,
        'DROP_LAST': feed.drop_last,
        'USE_PADDED_IM_INFO': feed.use_padded_im_info,
    }

    batch_transforms = feed.batch_transforms

    def _instances_of(kind):
        # All batch transforms that are instances of `kind`, in original order.
        return [t for t in batch_transforms if isinstance(t, kind)]

    pad_ops = _instances_of(PadBatch)
    rand_shape_ops = _instances_of(RandomShape)
    multi_scale_ops = _instances_of(MultiScale)
    pad_ms_test_ops = _instances_of(PadMSTest)

    # Only the first transform of each kind contributes its settings.
    if any(pad_ops):
        transform_config['IS_PADDING'] = True
        stride = pad_ops[0].pad_to_stride
        if stride != 0:
            transform_config['COARSEST_STRIDE'] = stride
    if any(rand_shape_ops):
        transform_config['RANDOM_SHAPES'] = rand_shape_ops[0].sizes
    if any(multi_scale_ops):
        transform_config['MULTI_SCALES'] = multi_scale_ops[0].scales
    if any(pad_ms_test_ops):
        transform_config['ENABLE_MULTISCALE_TEST'] = True
        transform_config['NUM_SCALE'] = feed.num_scale
        # Written after the PadBatch branch, so this value wins if both exist.
        transform_config['COARSEST_STRIDE'] = pad_ms_test_ops[0].pad_to_stride

    # `getfullargspec` is used when the running `inspect` module provides it;
    # otherwise fall back to `getargspec`.
    argspec = (inspect.getfullargspec
               if hasattr(inspect, 'getfullargspec') else inspect.getargspec)

    def _describe(op):
        # Keep only the instance attributes that are named positional
        # parameters of the op's `__init__` (excluding `self`), then tag the
        # entry with the op's class name.
        attrs = dict(op.__dict__)
        init_args = argspec(type(op).__init__).args
        described = {
            name: value
            for name, value in attrs.items()
            if name != 'self' and name in init_args
        }
        described['op'] = op.__class__.__name__
        return described

    transform_config['OPS'] = [_describe(op) for op in feed.sample_transforms]

    return Reader.create(feed.mode, data_config, transform_config, max_iter,
                         my_source)