Example #1
# Needed imports; `hp` is assumed to be the project's hyperparameter config.
from tensorpack.dataflow import BatchData, PrefetchData
from tensorpack.dataflow.remote import RemoteDataZMQ

def get_remote_dataflow(port, nr_prefetch=1000, nr_thread=1):
    # receive datapoints over IPC (local workers) and TCP (remote workers)
    ipc = 'ipc:///tmp/ipc-socket'
    tcp = 'tcp://0.0.0.0:%d' % port
    data_loader = RemoteDataZMQ(ipc, tcp, hwm=10000)
    # batch and prefetch in background before handing data to the trainer
    data_loader = BatchData(data_loader, batch_size=hp.train.batch_size)
    data_loader = PrefetchData(data_loader, nr_prefetch, nr_thread)
    return data_loader
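Example #1 only shows the consuming side of the pipeline. As a minimal sketch (not part of the original source), a producer process could push datapoints to the same endpoint with tensorpack's send_dataflow_zmq; the send_to_remote_dataflow helper, the host/port values, and the hwm setting below are assumptions.

# Hypothetical sender-side sketch: push datapoints from any DataFlow to the
# TCP endpoint that get_remote_dataflow() binds on the training machine.
from tensorpack.dataflow.remote import send_dataflow_zmq

def send_to_remote_dataflow(df_to_send, host, port):
    # df_to_send must yield the same datapoint layout the consumer expects
    send_dataflow_zmq(df_to_send, 'tcp://%s:%d' % (host, port), hwm=10000)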
Example #2
import argparse
import logging
import time

from tensorpack.dataflow.remote import RemoteDataZMQ

from pose_dataset import CocoPose

logging.basicConfig(
    level=logging.DEBUG,
    format='[lmdb_dataset] %(asctime)s %(levelname)s %(message)s')

if __name__ == '__main__':
    """
    Speed Test for Getting Input batches from other nodes
    """
    parser = argparse.ArgumentParser(
        description='Worker for preparing input batches.')
    parser.add_argument('--listen', type=str, default='tcp://0.0.0.0:1027')
    # use a store_true flag; `type=bool` would treat any non-empty string as True
    parser.add_argument('--show', action='store_true')
    args = parser.parse_args()

    df = RemoteDataZMQ(args.listen)

    logging.info('tcp queue start')
    df.reset_state()
    t = time.time()
    for i, dp in enumerate(df.get_data()):
        if i == 100:
            break
        logging.info('Input batch %d received.' % i)
        if i == 0:
            for j, d in enumerate(dp):
                logging.info('dp[%d] shape=%s' % (j, d.shape))

        if args.show:
            CocoPose.display_image(dp[0][0], dp[1][0], dp[2][0])
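The excerpt records a start time in t but ends before reporting it. A minimal sketch of a throughput summary that could follow the loop (an addition, not part of the original source):

    # hypothetical summary after the loop; `t` was set before the first batch
    elapsed = time.time() - t
    logging.info('received %d batches in %.2fs (%.2f batches/s)' % (i, elapsed, i / elapsed))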
Example #3
        heatmap_node = tf.placeholder(tf.float32,
                                      shape=(args.batchsize, output_h,
                                             output_w, 19),
                                      name='heatmap')
        object_node = tf.placeholder(tf.float32,
                                     shape=(args.batchsize, 2535, 85),
                                     name='yolo_out')

        # prepare data
        if not args.remote_data:
            df = get_dataflow_batch(args.datapath,
                                    True,
                                    args.batchsize,
                                    img_path=args.imgpath)
            print('train_df', df)
        else:
            # transfer inputs from ZMQ
            df = RemoteDataZMQ(args.remote_data, hwm=3)
        enqueuer = DataFlowToQueue(
            df, [input_node, heatmap_node, vectmap_node, object_node],
            queue_size=100)
        q_inp, q_heat, q_vect, q_obj = enqueuer.dequeue()
        print('inp/out', q_inp, q_heat, q_vect, q_obj)

    df_valid = get_dataflow_batch(args.datapath,
                                  False,
                                  args.batchsize,
                                  img_path=args.imgpath)
    print('val_df', df_valid)
    df_valid.reset_state()
    validation_cache = []

    val_image = get_sample_images(args.input_width, args.input_height)
Example #4
    vectmap_node = tf.placeholder(tf.float32,
                                  shape=(args.batchsize, output_h, output_w,
                                         38),
                                  name='vectmap')
    heatmap_node = tf.placeholder(tf.float32,
                                  shape=(args.batchsize, output_h, output_w,
                                         19),
                                  name='heatmap')
    # prepare data
    if not args.remote_data:
        df = get_dataflow_batch(args.datapath,
                                True,
                                args.batchsize,
                                img_path=args.imgpath)
    else:
        # transfer inputs from ZMQ
        df = RemoteDataZMQ(args.remote_data, hwm=3)
    # enqueuer = DataFlowToQueue(df, [input_node, heatmap_node, vectmap_node], queue_size=100)
    q_inp, q_heat, q_vect = input_node, heatmap_node, vectmap_node  # enqueuer.dequeue()

    #df_valid = get_dataflow_batch(args.datapath, False, args.batchsize, img_path=args.imgpath)
    # df_valid.reset_state()
    validation_cache = []

    # val_image = get_sample_images(args.input_width, args.input_height)
    # logger.info('tensorboard val image: %d' % len(val_image))
    logger.info(q_inp)
    logger.info(q_heat)
    logger.info(q_vect)

    # define model for multi-gpu
    # assumed completion of the truncated call: split each queued tensor across the GPUs
    q_inp_split, q_heat_split, q_vect_split = \
        tf.split(q_inp, args.gpus), tf.split(q_heat, args.gpus), tf.split(q_vect, args.gpus)