# Example No. 1
    print('epochs:', args.max_epoch)
    if args.gpus <= 0:
        raise Exception('gpus <= 0')

    # define input placeholder
    set_network_input_wh(args.input_width, args.input_height)
    scale = 4

    if args.model in [
            'cmu', 'vgg', 'mobilenet_thin', 'mobilenet_try', 'mobilenet_try2',
            'mobilenet_try3', 'hybridnet_try'
    ]:
        print('model:', args.model)
        scale = 8

    set_network_scale(scale)
    output_w, output_h = args.input_width // scale, args.input_height // scale

    logger.info('define model+')
    with tf.device(tf.DeviceSpec(device_type="CPU")):
        input_node = tf.placeholder(tf.float32,
                                    shape=(args.batchsize, args.input_height,
                                           args.input_width, 3),
                                    name='image')
        vectmap_node = tf.placeholder(tf.float32,
                                      shape=(args.batchsize, output_h,
                                             output_w, 38),
                                      name='vectmap')
        heatmap_node = tf.placeholder(tf.float32,
                                      shape=(args.batchsize, output_h,
                                             output_w, 19),
from tf_pose.pose_dataset import get_dataflow_batch
from tf_pose.pose_augment import set_network_input_wh, set_network_scale

def str2bool(value):
    """Parse a command-line string into a bool, for use as argparse ``type=``.

    ``type=bool`` is a trap: argparse calls ``bool(raw_string)`` and every
    non-empty string is truthy, so ``--train False`` would yield ``True``.
    This converter recognises the usual spellings instead.

    Args:
        value: The raw argument string. A ``bool`` is returned unchanged, so
            the helper is also safe on non-string defaults.

    Returns:
        ``True`` for yes/true/t/y/1 and ``False`` for no/false/f/n/0
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling;
            argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'boolean value expected, got %r' % (value,))


if __name__ == '__main__':
    # OpenPose data preparation might be a bottleneck for training.
    # Run multiple workers (possibly on several nodes) that generate input
    # batches and send them to the master, so training is not starved.
    parser = argparse.ArgumentParser(
        description='Worker for preparing input batches.')
    parser.add_argument('--datapath', type=str, default='/coco/annotations/')
    parser.add_argument('--imgpath', type=str, default='/coco/')
    parser.add_argument('--batchsize', type=int, default=64)
    # str2bool instead of bool: bool('False') is True, which made it
    # impossible to turn this flag off from the command line.
    parser.add_argument('--train', type=str2bool, default=True)
    parser.add_argument('--master',
                        type=str,
                        default='tcp://csi-cluster-gpu20.dakao.io:1027')
    parser.add_argument('--input-width', type=int, default=368)
    parser.add_argument('--input-height', type=int, default=368)
    parser.add_argument('--scale-factor', type=int, default=2)
    args = parser.parse_args()

    # Configure tf_pose.pose_augment with the network input size and the
    # scale factor used when building the batches.
    set_network_input_wh(args.input_width, args.input_height)
    set_network_scale(args.scale_factor)

    df = get_dataflow_batch(args.datapath, args.train, args.batchsize,
                            args.imgpath)

    # NOTE(review): send_dataflow_zmq is not imported in the visible lines;
    # presumably tensorpack's remote dataflow sender, imported above. `hwm`
    # is presumably the ZMQ high-water mark (max queued messages) — confirm.
    send_dataflow_zmq(df, args.master, hwm=10)