def get_tp_loader(data_dir, name, batch_size, parallel=None):
    """Build the data loader for one ImageNet split.

    Two storage layouts are supported, selected by the suffix of ``data_dir``:

    * ``data_dir`` ends with ``'lmdb'``: the file ``ILSVRC-<name>.lmdb``
      inside it is opened without shuffling and wrapped by
      ``get_sequential_loader``.
    * anything else: the directory is read through ``dataset.ILSVRC12``
      (shuffled) for the ``'train'`` split or ``dataset.ILSVRC12Files``
      (unshuffled) for every other split, and wrapped by
      ``get_random_loader``.

    Args:
        data_dir: Root location of the dataset (see the two layouts above).
        name: Split name; ``'train'`` enables the training code path.
        batch_size: Forwarded unchanged to the loader factory.
        parallel: Forwarded unchanged to the loader factory.

    Returns:
        Whatever ``get_sequential_loader`` / ``get_random_loader`` returns.
    """
    is_training = name == 'train'
    augmentors = get_tp_augmentor(is_training)

    if not data_dir.endswith('lmdb'):
        # Timing note kept from the original author for this path:
        # 500000[27:11:03, 5.11it/s]
        if is_training:
            source = dataset.ILSVRC12(data_dir, name, shuffle=True)
        else:
            source = dataset.ILSVRC12Files(data_dir, name, shuffle=False)
        return get_random_loader(
            source, is_training, batch_size, augmentors, parallel)

    # Timing note kept from the original author for this path:
    # 500000[70:87:20, 1.95it/s]
    lmdb_file = os.path.join(data_dir, 'ILSVRC-%s.lmdb' % name)
    source = LMDBSerializer.load(lmdb_file, shuffle=False)
    return get_sequential_loader(
        source, is_training, batch_size, augmentors, parallel)
def get_infer_iterator(dataset, hparams, lmdb_path):
    """Create the inference dataflow backed by an LMDB file.

    The dataset is first serialized via ``serialize_to_lmdb``; the resulting
    LMDB is then read back in order (no shuffling), grouped into batches of
    ``hparams.infer_batch_size`` with ``remainder=False``, and every component
    of each batch is split into ``hparams.num_gpu`` pieces with
    ``np.array_split``. The result is served through a single
    ``PrefetchDataZMQ`` worker.

    Args:
        dataset: Passed through to ``serialize_to_lmdb``.
        hparams: Object providing ``infer_batch_size`` and ``num_gpu``.
        lmdb_path: Path of the LMDB file to write and then read.

    Returns:
        The ``PrefetchDataZMQ`` dataflow wrapping the pipeline above.
    """
    # NOTE(review): get_iterator in this file calls serialize_to_lmdb with
    # (hparams, dataset, lmdb_path); here the order is (dataset, hparams,
    # lmdb_path). Confirm which order the helper's signature expects.
    serialize_to_lmdb(dataset, hparams, lmdb_path)

    per_step = hparams.infer_batch_size
    gpu_count = hparams.num_gpu

    def _split_across_gpus(datapoint):
        # One list of `gpu_count` chunks per component of the batched datapoint.
        return [
            np.array_split(datapoint[pos], gpu_count)
            for pos in range(len(datapoint))
        ]

    flow = LMDBSerializer.load(lmdb_path, shuffle=False)
    flow = BatchData(flow, batch_size=per_step, remainder=False)
    flow = MapData(flow, _split_across_gpus)
    return PrefetchDataZMQ(flow, nr_proc=1, hwm=per_step * 10)
def get_iterator(hparams, dataset, lmdb_path, shuffle=True,
                 drop_remainder=True, nr_proc=4):
    """Create a batched, per-GPU-split dataflow backed by an LMDB file.

    The dataset is first serialized via ``serialize_to_lmdb``; the resulting
    LMDB is then loaded, grouped into batches of ``hparams.batch_size``, and
    every component of each batch is split into ``hparams.num_gpu`` pieces
    with ``np.array_split``. The result is served through ``PrefetchDataZMQ``.

    Args:
        hparams: Object providing ``batch_size`` and ``num_gpu``.
        dataset: Passed through to ``serialize_to_lmdb``.
        lmdb_path: Path of the LMDB file to write and then read.
        shuffle: Forwarded to ``LMDBSerializer.load``.
        drop_remainder: When true, ``BatchData`` is built with
            ``remainder=False``; when false, with ``remainder=True``.
        nr_proc: Number of processes given to ``PrefetchDataZMQ``.

    Returns:
        The ``PrefetchDataZMQ`` dataflow wrapping the pipeline above.
    """
    # NOTE(review): get_infer_iterator in this file calls serialize_to_lmdb
    # with (dataset, hparams, lmdb_path); here the order is (hparams, dataset,
    # lmdb_path). Confirm which order the helper's signature expects.
    serialize_to_lmdb(hparams, dataset, lmdb_path)

    per_step = hparams.batch_size
    gpu_count = hparams.num_gpu

    def _split_across_gpus(datapoint):
        # One list of `gpu_count` chunks per component of the batched datapoint.
        return [
            np.array_split(datapoint[pos], gpu_count)
            for pos in range(len(datapoint))
        ]

    keep_tail = not drop_remainder
    flow = LMDBSerializer.load(lmdb_path, shuffle=shuffle)
    flow = BatchData(flow, batch_size=per_step, remainder=keep_tail)
    flow = MapData(flow, _split_across_gpus)
    return PrefetchDataZMQ(flow, nr_proc=nr_proc, hwm=per_step * 10)
# Root of the ImageNet data; the script fails with KeyError if IMAGENET is unset.
imagenet_path = os.environ['IMAGENET']

# Convert (unless --check_only style flag is set on `args`) and then verify
# each split. Only 'train' and 'val' are handled; ['test'] is left out.
for name in ['train', 'val']:
    ds0 = BinaryILSVRC12(imagenet_path, name)
    ds1 = MultiProcessRunnerZMQ(ds0, nr_proc=1)
    # Earlier variant kept for reference:
    # dftools.dump_dataflow_to_lmdb(ds1, os.path.join(imagenet_path,'ILSVRC-%s.lmdb'%name))

    # One output file when args.n == 1, otherwise args.n numbered files.
    single_file = args.n == 1
    if single_file:
        paths = [os.path.join(imagenet_path, 'ILSVRC-%s.lmdb' % name)]
    else:
        paths = [
            os.path.join(imagenet_path, 'ILSVRC-%s-%d.lmdb' % (name, i))
            for i in range(args.n)
        ]

    if not args.check_only:
        if single_file:
            LMDBSerializer.save(ds1, paths[0])
        else:
            print("Saving to %d files:\n%s\n" % (args.n, "\n".join(paths)))
            LMDBSplitSaver.save(ds1, paths, args.n)

    # Sanity check: total record count across the LMDB file(s) must equal the
    # length of the source dataflow.
    orig_total_img_count = len(ds0)
    lmdb_total_img_count = 0
    for shard_path in paths:
        ds = LMDBSerializer.load(shard_path, shuffle=False)
        lmdb_total_img_count += len(ds)

    print("'%s' orig: %d, lmdb: %d." % (name, orig_total_img_count, lmdb_total_img_count), end=' ')
    if orig_total_img_count == lmdb_total_img_count:
        print("Matched!")
    else:
        print("Mismatch!")
        # Drop into the debugger so the discrepancy can be inspected in place.
        pdb.set_trace()