default=32, type=int) parser.add_argument('--benchmark', action='store_true') parser.add_argument('--no-zmq-ops', action='store_true') args = parser.parse_args() os.environ['CUDA_VISIBLE_DEVICES'] = '' if args.fake: ds = FakeData([[args.batch, 224, 224, 3], [args.batch]], 1000, random=False, dtype=['uint8', 'int32']) else: augs = fbresnet_augmentor(True) ds = get_data(args.batch, augs) logger.info("Serving data on {}".format(socket.gethostname())) if args.benchmark: from zmq_ops import dump_arrays ds = MapData(ds, dump_arrays) TestDataSpeed(ds, warmup=300).start() else: format = None if args.no_zmq_ops else 'zmq_ops' send_dataflow_zmq(ds, 'ipc://@imagenet-train-b{}'.format(args.batch), hwm=150, format=format, bind=True)
# Augment the MNIST training split with tensorpack and stream the batched
# result to a ZMQ endpoint (tcp://localhost:2222) via send_dataflow_zmq.
import tensorpack.dataflow as df


def _mnist_augmentors():
    """Build the ordered augmentation chain applied to every MNIST image.

    Returns:
        list: tensorpack imgaug augmentors, applied in list order.
    """
    aug = df.imgaug
    # Randomized distortions; the second argument of RandomApplyAug is the
    # probability with which the wrapped augmentor is applied.
    random_stage = [
        aug.RandomApplyAug(aug.RandomResize((0.8, 1.2), (0.8, 1.2)), 0.3),
        aug.RandomApplyAug(aug.RotationAndCropValid(15), 0.5),
        aug.RandomApplyAug(
            aug.SaltPepperNoise(white_prob=0.01, black_prob=0.01), 0.25),
    ]
    # Resize back to 28x28, paste onto a 32x32 canvas, then take a random
    # 28x28 crop from it.
    geometry_stage = [
        aug.Resize((28, 28)),
        aug.CenterPaste((32, 32)),
        aug.RandomCrop((28, 28)),
    ]
    # Append a trailing channel axis: (28, 28) -> (28, 28, 1).
    channel_stage = [aug.MapImage(lambda img: img.reshape(28, 28, 1))]
    return random_stage + geometry_stage + channel_stage


if __name__ == '__main__':
    flow = df.dataset.Mnist('train')
    flow = df.AugmentImageComponent(flow, _mnist_augmentors())
    # Fixed-size batches of 32; remainder=False means no smaller final batch.
    flow = df.BatchData(flow, batch_size=32, remainder=False)
    # Produce batches in background processes (nr_proc=2) with a prefetch
    # buffer of size 12 (nr_prefetch).
    flow = df.PrefetchData(flow, nr_prefetch=12, nr_proc=2)
    # Pass-through stage that prints information about produced datapoints.
    flow = df.PrintData(flow)
    df.send_dataflow_zmq(flow, 'tcp://localhost:2222')
default=32, type=int) parser.add_argument('--warmup', help='prefetch buffer size', default=150, type=int) parser.add_argument('--port', help='server port', default=1000, type=int) parser.add_argument('--benchmark', action='store_true') parser.add_argument('--no-zmq-ops', action='store_true') args = parser.parse_args() os.environ['CUDA_VISIBLE_DEVICES'] = '' if args.fake: ds = FakeData( [[args.batch, args.image_size, args.image_size, 3], [args.batch]], 1000, random=False, dtype=['uint8', 'int32']) else: augs = fbresnet_augmentor(True, image_size=args.image_size) ds = get_data(args.batch, augs, args.worker) logger.info("Serving data on {}".format(socket.gethostname())) if args.benchmark: from zmq_ops import dump_arrays ds = MapData(ds, dump_arrays) TestDataSpeed(ds, warmup=300).start() else: format = None if args.no_zmq_ops else 'zmq_ops' send_dataflow_zmq( ds, 'ipc://@imagenet-train-b{}-p{}'.format(args.batch, args.port), hwm=args.warmup, format=format, bind=True)