Пример #1
0
    def __init__(self,
                 mode,
                 batch_size=256,
                 shuffle=False,
                 num_workers=25,
                 cache=50000,
                 collate_fn=default_collate,
                 drop_last=False,
                 cuda=False):
        """Assemble the batched ImageNet input pipeline for one split.

        Records are read from ``$IMAGENET/ILSVRC-<mode>.lmdb``, locally
        shuffled (for every mode), decoded with OpenCV, augmented and then
        grouped into batches. The finished dataflow is stored in ``self.ds``
        and has ``reset_state()`` called on it before the constructor returns.

        Args:
            mode: split name substituted into the LMDB file name; the
                augmentors are requested in training form when it equals
                ``'train'``.
            batch_size: batch size handed to ``td.BatchData``.
            shuffle: accepted but not read by this constructor.
            num_workers: worker count handed to ``td.PrefetchDataZMQ``.
            cache: second argument of ``td.LocallyShuffleData``
                (presumably the shuffle buffer size -- confirm).
            collate_fn: accepted but not read by this constructor.
            drop_last: accepted but not read by this constructor.
            cuda: stored unchanged as ``self.cuda``.

        Raises:
            KeyError: if the ``IMAGENET`` environment variable is not set.
        """
        def decode_image(raw):
            # Decode the raw component with OpenCV using the IMREAD_COLOR flag.
            return cv2.imdecode(raw, cv2.IMREAD_COLOR)

        # Standard ImageNet augmentors; the flag tells the factory whether
        # this is the training split.
        is_train = mode == 'train'
        augmentors = fbresnet_augmentor(is_train)

        # The LMDB for this split lives directly under $IMAGENET.
        dataset_root = os.environ['IMAGENET']
        lmdb_path = os.path.join(dataset_root, 'ILSVRC-%s.lmdb' % mode)

        # Every stage wraps the previous one, so the order is significant.
        flow = td.LMDBData(lmdb_path, shuffle=False)
        flow = td.LocallyShuffleData(flow, cache)
        flow = td.PrefetchData(flow, 5000, 1)
        flow = td.LMDBDataPoint(flow)
        flow = td.MapDataComponent(flow, decode_image, 0)  # component 0 only
        flow = td.AugmentImageComponent(flow, augmentors)
        flow = td.PrefetchDataZMQ(flow, num_workers)
        self.ds = td.BatchData(flow, batch_size)
        self.ds.reset_state()

        # Keep the loader settings around for later use by the instance.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.cuda = cuda
Пример #2
0
    def __init__(self,
                 mode,
                 batch_size=256,
                 shuffle=False,
                 num_workers=25,
                 cache=50000,
                 collate_fn=default_collate,
                 remainder=False,
                 cuda=False,
                 transform=None):
        """Assemble the batched ImageNet input pipeline for one split.

        Records are read from ``$IMAGENET/ILSVRC-<mode>.lmdb``, locally
        shuffled only when ``mode == 'train'``, decoded to RGB arrays through
        PIL, passed through ``transform`` and then grouped into batches. The
        finished dataflow is stored in ``self.ds`` and has ``reset_state()``
        called on it before the constructor returns.

        Args:
            mode: split name substituted into the LMDB file name; ``'train'``
                additionally enables local shuffling.
            batch_size: batch size handed to ``td.BatchData``.
            shuffle: accepted but not read by this constructor.
            num_workers: worker count handed to ``td.PrefetchDataZMQ``.
            cache: second argument of ``td.LocallyShuffleData``
                (presumably the shuffle buffer size -- confirm).
            collate_fn: accepted but not read by this constructor.
            remainder: forwarded to ``td.BatchData`` as ``remainder``.
            cuda: stored unchanged as ``self.cuda``.
            transform: wrapped in ``ImgAugTVCompose`` and used as the single
                augmentor; the default ``None`` is passed through as-is.

        Raises:
            KeyError: if the ``IMAGENET`` environment variable is not set.
        """
        def decode_rgb(raw):
            # Open the raw bytes with PIL, force RGB, and expose as an array.
            image = Image.open(io.BytesIO(raw))
            return np.asarray(image.convert('RGB'))

        # The caller-supplied transform is the only augmentor in use.
        augmentors = [ImgAugTVCompose(transform)]

        # The LMDB for this split lives directly under $IMAGENET.
        dataset_root = os.environ['IMAGENET']
        lmdb_path = os.path.join(dataset_root, 'ILSVRC-%s.lmdb' % mode)

        # Every stage wraps the previous one, so the order is significant.
        flow = td.LMDBData(lmdb_path, shuffle=False)
        training = mode == 'train'
        if training:
            flow = td.LocallyShuffleData(flow, cache)
        flow = td.PrefetchData(flow, 5000, 1)
        flow = td.LMDBDataPoint(flow)
        flow = td.MapDataComponent(flow, decode_rgb, 0)  # component 0 only
        flow = td.AugmentImageComponent(flow, augmentors)
        flow = td.PrefetchDataZMQ(flow, num_workers)
        self.ds = td.BatchData(flow, batch_size, remainder=remainder)
        self.ds.reset_state()

        # Keep the loader settings around for later use by the instance.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.cuda = cuda
Пример #3
0
def create_dataflow(data_dir: Path,
                    kind: str,
                    batch_size: int,
                    shuffle: bool = True) -> td.DataFlow:
    """Build a batched dataflow over the LMDB file ``<data_dir>/<kind>.mdb``.

    Each record is passed through ``_decode_data``, the decoded datapoints
    are batched with ``remainder=False``, and ``_squeeze_last`` is then
    applied to component 1 of every batch.

    Args:
        data_dir: directory that contains the ``.mdb`` file.
        kind: file stem of the LMDB file to open.
        batch_size: batch size handed to ``td.BatchData``.
        shuffle: forwarded to ``td.LMDBData`` as ``shuffle``.

    Returns:
        The final ``td.MapDataComponent`` stage of the pipeline. This
        function does not call ``reset_state()`` on it.
    """
    lmdb_file = str(data_dir / "{}.mdb".format(kind))

    records = td.LMDBData(lmdb_file, shuffle=shuffle)
    decoded = td.MapData(records, _decode_data)
    batched = td.BatchData(decoded, batch_size, remainder=False)
    # The squeeze runs after batching, i.e. on whole batches of component 1.
    return td.MapDataComponent(batched, _squeeze_last, index=1)
Пример #4
0
def lmdb_dataflow(lmdb_path, batch_size, input_size, output_size, is_training, test_speed=False):
    """Build a repeated, batched dataflow over an LMDB file.

    Pipeline, in order: ``LMDBData`` -> (training only) ``LocallyShuffleData``
    -> ``PrefetchData`` -> ``LMDBDataPoint`` -> ``PreprocessData`` ->
    (training only) ``PrefetchDataZMQ`` -> ``BatchData`` -> ``RepeatedData``.

    Args:
        lmdb_path: path handed to ``dataflow.LMDBData``.
        batch_size: batch size handed to ``dataflow.BatchData``.
        input_size: forwarded to ``PreprocessData``.
        output_size: forwarded to ``PreprocessData``.
        is_training: when truthy, enables local shuffling and the ZMQ
            prefetch stage.
        test_speed: when truthy, runs ``dataflow.TestDataSpeed`` with
            ``size=1000`` on the finished pipeline before returning.

    Returns:
        A ``(dataflow, size)`` tuple. ``size`` is what ``LMDBData.size()``
        reported for the raw reader, taken before any batching. The dataflow
        has had ``reset_state()`` called on it.
    """
    reader = dataflow.LMDBData(lmdb_path, shuffle=False)
    # Read the size from the raw reader, before any other stage wraps it.
    size = reader.size()

    pipeline = reader
    if is_training:
        pipeline = dataflow.LocallyShuffleData(pipeline, buffer_size=2000)
    pipeline = dataflow.PrefetchData(pipeline, nr_prefetch=500, nr_proc=1)
    pipeline = dataflow.LMDBDataPoint(pipeline)
    pipeline = PreprocessData(pipeline, input_size, output_size)
    if is_training:
        pipeline = dataflow.PrefetchDataZMQ(pipeline, nr_proc=8)

    batched = dataflow.BatchData(pipeline, batch_size, use_list=True)
    # -1 presumably requests endless repetition -- confirm against tensorpack.
    repeated = dataflow.RepeatedData(batched, -1)

    if test_speed:
        # NOTE(review): TestDataSpeed.start() may reset the dataflow itself;
        # confirm the reset_state() call below is safe as a second reset for
        # the installed tensorpack version.
        dataflow.TestDataSpeed(repeated, size=1000).start()

    repeated.reset_state()
    return repeated, size