def callback(data):
  """Finish the pose pipeline from received sparse tensors and publish people.

  Decodes every sparse tensor in the message, runs the detector's mid-to-final
  stages, groups keypoints into people, and publishes a PersonArray with
  coordinates normalized by the affinity map size.
  """
  if people_pub.get_num_connections() == 0:
    return  # no subscribers; skip the expensive inference entirely
  feeds = {'key_point_threshold': pose_params['key_point_threshold']}
  for tensor in data.sparse_tensors:
    feeds[tensor.name] = [decode_sparse_tensor(tensor)]  # batch of one
  t_begin = rospy.Time.now()
  outputs = pose_detector.compute(['part_affinity_fields', 'key_points'], feeds)
  t_inference = rospy.Time.now()
  rospy.loginfo('Mid-stage to stage4: {} sec'.format((t_inference - t_begin).to_sec()))

  affinity = outputs[0][0]
  keypoints = outputs[1][:, 1:]  # drop leading column of each keypoint row
  raw_people = connect_parts(affinity, keypoints, limbs,
                             line_division=pose_params['line_division'],
                             threshold=pose_params['affinity_threshold'])
  height, width = affinity.shape[:2]
  # Turn part-index -> keypoint-index mappings into part-name -> normalized
  # (x, y); keypoint rows hold (row, col), hence the index swap below.
  people = []
  for person in raw_people:
    named_parts = {}
    for part_idx, kp_idx in person.items():
      named_parts[pose_detector._part_names[part_idx]] = (
          (keypoints[kp_idx][1] + 0.5) / width,
          (keypoints[kp_idx][0] + 0.5) / height)
    people.append(named_parts)
  t_done = rospy.Time.now()
  time_pub.publish(t_done - t_inference)

  rospy.loginfo("Post processing: {} sec".format((t_done - t_inference).to_sec()))
  rospy.loginfo("Total: {} sec".format((t_done - t_begin).to_sec()))
  print("Found {} people.".format(len(people)))

  msg = PersonArray()
  msg.header = data.header
  msg.people = [Person(body_parts=[KeyPoint(name=part, x=x, y=y)
                                   for part, (x, y) in person.items()])
                for person in people]
  people_pub.publish(msg)
def callback(data):
  """Image-topic callback: publish detected people and/or stage-1 tensors.

  Work is only performed for outputs that currently have subscribers.
  The image is converted once and shared by both branches.
  """
  if people_pub.get_num_connections() > 0 or stage1_pub.get_num_connections() > 0:
    try:
      cv_image = bridge.imgmsg_to_cv2(data, 'rgb8')
    except CvBridgeError as e:
      rospy.logerr(e)
      return

  if people_pub.get_num_connections() > 0:
    persons = pose_detector.detect_keypoints(cv_image, **pose_params)
    msg = PersonArray(header=data.header)
    msg.people = [Person(body_parts=[KeyPoint(name=k, x=x, y=y)
                                     for k, (x, y) in p.items()])
                  for p in persons]
    people_pub.publish(msg)

  if stage1_pub.get_num_connections() > 0:
    msg = SparseTensorArray()
    msg.header = data.header
    fetch_list = [pose_detector.end_points['stage0'],
                  pose_detector.end_points['stage1_L2']]
    feed_dict = {pose_detector.ph_x: cv_image / 255.}  # network expects [0, 1]
    outputs = pose_detector.sess.run(fetch_list, feed_dict)
    stage0 = encode_sparse_tensors(outputs[0], threshold=0.1)
    stage0.name = 'stage0'
    # BUG FIX: previously encoded outputs[0] (the stage0 tensor) again, so the
    # published 'stage1_L2' tensor duplicated stage0 data. stage1_L2 is the
    # second fetched tensor, i.e. outputs[1].
    stage1_L2 = encode_sparse_tensors(outputs[1], threshold=0.1)
    stage1_L2.name = 'stage1_L2'
    msg.sparse_tensors = [stage0, stage1_L2]
    stage1_pub.publish(msg)
def detect_keypoints(req, detector):
  """Service handler: return the keypoints of the first person in req.image.

  Returns None (i.e. a failed service call) if the image cannot be converted.
  """
  try:
    image = bridge.imgmsg_to_cv2(req.image, 'rgb8')
  except CvBridgeError as err:
    rospy.logerr(err)
    return None
  # Only the first detected person is reported by this service.
  first_person = detector.detect_keypoints(image, **pose_params)[0]
  response = DetectKeyPointsResponse()
  response.key_points = [KeyPoint(name=part, x=px, y=py)
                         for part, (px, py) in first_person.items()]
  return response
def detect_people(req):
  """Service handler: detect every person in req.image.

  Returns None (i.e. a failed service call) if the image cannot be converted.
  """
  try:
    image = bridge.imgmsg_to_cv2(req.image, 'rgb8')
  except CvBridgeError as err:
    rospy.logerr(err)
    return None
  detections = pose_detector.detect_keypoints(image, **pose_params)
  response = DetectPeopleResponse()
  response.people = [Person(body_parts=[KeyPoint(name=part, x=px, y=py)
                                        for part, (px, py) in person.items()])
                     for person in detections]
  return response
def callback(data):
  """Image-topic callback: detect people and publish a PersonArray.

  Skips all work when the people topic has no subscribers; logs and drops
  the frame on image-conversion failure.
  """
  if not people_pub.get_num_connections():
    return
  try:
    image = bridge.imgmsg_to_cv2(data, 'rgb8')
  except CvBridgeError as err:
    rospy.logerr(err)
    return
  detections = pose_detector.detect_keypoints(image, **pose_params)
  message = PersonArray(header=data.header)
  message.people = [Person(body_parts=[KeyPoint(name=part, x=px, y=py)
                                       for part, (px, py) in person.items()])
                    for person in detections]
  people_pub.publish(message)
def compute(req):
  """Service handler: run a segment of the pose network.

  req.input selects the feed point: 'image' (start from raw pixels) or
  'stageN' (N in 1..6, feeding the pre-decoded sparse tensors in
  req.feature_map, req.affinity_field and req.confidence_map).
  req.output selects the fetch point: 'people' (full post-processed
  detections) or 'stageN' (sparse-encoded intermediate tensors).
  Raises ValueError for malformed stage names or an output stage lower
  than the input stage.
  """
  # determine tensor to feed
  if req.input == 'image':
    input_stage = 0
    cv_image = bridge.imgmsg_to_cv2(req.image, 'rgb8')
    cv_image = cv_image/255.  # scale pixel values into [0, 1]
    
    if pose_detector.input_shape is not None:
      # input_shape is indexed (rows, cols); cv2.resize wants (width, height)
      cv_image = cv2.resize(cv_image, pose_detector.input_shape[::-1])
    image_batch = np.expand_dims(cv_image, 0)  # add leading batch axis
    feed_dict = {pose_detector.ph_x: image_batch}
  else:
    if not req.input.startswith('stage'):
      raise ValueError('Argument "input" must be "image" or starts with "stage".')
    if req.input[5:] not in ['1', '2', '3', '4', '5', '6']:
      raise ValueError('Argument "input" specifies illegal stage.')
    input_stage = int(req.input[5:])
    # Feed the decoded tensors straight into this stage's graph endpoints;
    # each is wrapped in a list to form a batch of one.
    feed_dict = {
      pose_detector.end_points[req.input+'_L1']: [decode_sparse_tensor(req.affinity_field)],
      pose_detector.end_points[req.input+'_L2']: [decode_sparse_tensor(req.confidence_map)],
      pose_detector.end_points['stage0']: [decode_sparse_tensor(req.feature_map)]
    }
  feed_dict[pose_detector.ph_threshold] = pose_params['key_point_threshold']
  
  # determine tensor to fetch
  if req.output == 'people':
    output_stage = 6  # NOTE(review): assigned but never read in this branch
    fetch_list = [pose_detector.heat_map,
                  pose_detector.affinity,
                  pose_detector.keypoints]
    rospy.loginfo('Start processing.')
    heat_map, affinity, keypoints = pose_detector.sess.run(fetch_list, feed_dict)
    rospy.loginfo('Done.')
    # TODO: Scale is not always 8. It varies according to the preprocessing.
    scale_x = 8.
    scale_y = 8.
    # Map grid coordinates back to pixel coordinates (cell centers).
    # Each keypoint row unpacks as (_, y, x, c); c is unused here —
    # presumably a channel or confidence value, confirm against the detector.
    inlier_lists = []
    for _,y,x,c in keypoints:
      x = x*scale_x + scale_x/2
      y = y*scale_y + scale_y/2
      inlier_lists.append((x,y))
      
    persons = connect_parts(affinity[0], keypoints[:,1:], pose_detector.limbs,
                            line_division = pose_params['line_division'],
                            threshold = pose_params['affinity_threshold'])
    # Replace part indices with part names and keypoint indices with the
    # pixel coordinates computed above.
    persons = [{pose_detector.part_names[k]:inlier_lists[v] \
                for k,v in person.items()} for person in persons]
    msg = PersonArray()
    msg.people = [Person(body_parts=[KeyPoint(name=k, x=x, y=y) \
                                     for k,(x,y) in p.items()]) \
                  for p in persons]
    return ComputeResponse(people=msg)
  else:
    if not req.output.startswith('stage'):
      raise ValueError('Argument "output" must be "people" or starts with "stage".')
    if req.output[5:] not in ['1', '2', '3', '4', '5', '6']:
      raise ValueError('Argument "output" specifies illegal stage.')
    output_stage = int(req.output[5:])
    # NOTE(review): this check permits output_stage == input_stage even though
    # the message says "greater" — confirm whether equality should be rejected.
    if output_stage < input_stage:
      raise ValueError('Output stage must be greater than input stage.')
    fetch_list = [
      pose_detector.end_points[req.output+'_L1'],
      pose_detector.end_points[req.output+'_L2'],
      pose_detector.end_points['stage0']
    ]
    affinity, heat_map, feat_map = pose_detector.sess.run(fetch_list, feed_dict)
    # Strip the batch dimension ([0]) and return sparse-encoded tensors.
    return ComputeResponse(feature_map=encode_sparse_tensor(feat_map[0], threshold=0.2),
                           affinity_field=encode_sparse_tensor(affinity[0]),
                           confidence_map=encode_sparse_tensor(heat_map[0]))