def get_remote_dataflow(port, nr_prefetch=1000, nr_thread=1):
    """Assemble the batched, prefetching dataflow built on ``RemoteDataZMQ``.

    Args:
        port: TCP port number; formatted with ``%d`` into
            ``tcp://0.0.0.0:<port>``.
        nr_prefetch: second positional argument handed to ``PrefetchData``.
        nr_thread: third positional argument handed to ``PrefetchData``.

    Returns:
        The ``PrefetchData`` wrapper around the batched remote dataflow.
    """
    # The receiver is given two addresses: a fixed local IPC socket and a
    # TCP address on all interfaces at the requested port.
    addresses = ('ipc:///tmp/ipc-socket', 'tcp://0.0.0.0:%d' % port)
    remote = RemoteDataZMQ(*addresses, hwm=10000)
    # Batch size comes from the module-level hyper-parameter object ``hp``.
    batched = BatchData(remote, batch_size=hp.train.batch_size)
    return PrefetchData(batched, nr_prefetch, nr_thread)
logging.basicConfig(
    level=logging.DEBUG,
    format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s')


def _str2bool(value):
    """Parse a command-line boolean such as ``true`` / ``false``.

    ``type=bool`` treats every non-empty string (including ``'False'``) as
    ``True``; this parser interprets the text instead.

    Args:
        value: the raw command-line string (an actual ``bool`` is passed
            through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    # Speed test for getting input batches from other nodes.
    parser = argparse.ArgumentParser(
        description='Worker for preparing input batches.')
    parser.add_argument('--listen', type=str, default='tcp://0.0.0.0:1027')
    # Accepts ``--show``, ``--show True`` and ``--show False``.
    parser.add_argument('--show', type=_str2bool, nargs='?', const=True,
                        default=False)
    args = parser.parse_args()

    df = RemoteDataZMQ(args.listen)

    logging.info('tcp queue start')
    df.reset_state()
    # Start timestamp of the test; not read within this block.
    t = time.time()
    for i, dp in enumerate(df.get_data()):
        # Stop after 100 batches have been consumed.
        if i == 100:
            break
        logging.info('Input batch %d received.', i)
        if i == 0:
            # Log the shape of every component of the first datapoint,
            # prefixed with the component's position.
            for idx, d in enumerate(dp):
                logging.info('%d dp shape=%s', idx, d.shape)

        if args.show:
            # Display the first element of each of the first three components.
            CocoPose.display_image(dp[0][0], dp[1][0], dp[2][0])
output_w, 19), name='heatmap') object_node = tf.placeholder(tf.float32, shape=(args.batchsize, 2535, 85), name='yolo_out') # prepare data if not args.remote_data: df = get_dataflow_batch(args.datapath, True, args.batchsize, img_path=args.imgpath) print('train_df', df) else: # transfer inputs from ZMQ df = RemoteDataZMQ(args.remote_data, hwm=3) enqueuer = DataFlowToQueue( df, [input_node, heatmap_node, vectmap_node, object_node], queue_size=100) q_inp, q_heat, q_vect, q_obj = enqueuer.dequeue() print('inp/out', q_inp, q_heat, q_vect, q_obj) df_valid = get_dataflow_batch(args.datapath, False, args.batchsize, img_path=args.imgpath) print('val_df', df_valid) df_valid.reset_state() validation_cache = [] val_image = get_sample_images(args.input_width, args.input_height)
shape=(args.batchsize, output_h, output_w, 38), name='vectmap') heatmap_node = tf.placeholder(tf.float32, shape=(args.batchsize, output_h, output_w, 19), name='heatmap') # prepare data if not args.remote_data: df = get_dataflow_batch(args.datapath, True, args.batchsize, img_path=args.imgpath) else: # transfer inputs from ZMQ df = RemoteDataZMQ(args.remote_data, hwm=3) # enqueuer = DataFlowToQueue(df, [input_node, heatmap_node, vectmap_node], queue_size=100) q_inp, q_heat, q_vect = input_node, heatmap_node, vectmap_node # enqueuer.dequeue() #df_valid = get_dataflow_batch(args.datapath, False, args.batchsize, img_path=args.imgpath) # df_valid.reset_state() validation_cache = [] # val_image = get_sample_images(args.input_width, args.input_height) # logger.info('tensorboard val image: %d' % len(val_image)) logger.info(q_inp) logger.info(q_heat) logger.info(q_vect) # define model for multi-gpu q_inp_split, q_heat_split, q_vect_split = tf.split(
from tensorpack.dataflow.remote import RemoteDataZMQ

from pose_dataset import CocoPose

logging.basicConfig(level=logging.DEBUG,
                    format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s')


def _str2bool(value):
    """Parse a command-line boolean such as ``true`` / ``false``.

    ``type=bool`` treats every non-empty string (including ``'False'``) as
    ``True``; this parser interprets the text instead.

    Args:
        value: the raw command-line string (an actual ``bool`` is passed
            through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    # Speed test for getting input batches from other nodes.
    parser = argparse.ArgumentParser(description='Worker for preparing input batches.')
    parser.add_argument('--listen', type=str, default='tcp://0.0.0.0:1027')
    # Accepts ``--show``, ``--show True`` and ``--show False``.
    parser.add_argument('--show', type=_str2bool, nargs='?', const=True,
                        default=False)
    args = parser.parse_args()

    df = RemoteDataZMQ(args.listen)

    logging.info('tcp queue start')
    df.reset_state()
    # Start timestamp of the test; not read within this block.
    t = time.time()
    for i, dp in enumerate(df.get_data()):
        # Stop after 100 batches have been consumed.
        if i == 100:
            break
        logging.info('Input batch %d received.', i)
        if i == 0:
            # Log the shape of every component of the first datapoint,
            # prefixed with the component's position.
            for idx, d in enumerate(dp):
                logging.info('%d dp shape=%s', idx, d.shape)

        if args.show:
            # Display the first element of each of the first three components.
            CocoPose.display_image(dp[0][0], dp[1][0], dp[2][0])