예제 #1
0
def _pad_or_clip_point_properties(inputs, pad_or_clip_size):
  """Pads or clips the point-property tensors of `inputs` in place.

  The number of valid points is always recorded under
  `standard_fields.InputDataFields.num_valid_points`, taken from the first
  entry of `tf.shape` of the point positions tensor. When `pad_or_clip_size`
  is given, that count is capped with `tf.minimum`. Every point field present
  in `inputs`, other than the count itself, is then passed through
  `shape_utils.pad_or_clip_nd`. The target shape's first dimension is
  `pad_or_clip_size` and its remaining dimensions are the tensor's static
  dimensions.

  Args:
    inputs: A dictionary containing input tensors. Modified in place.
    pad_or_clip_size: Number of target points to pad or clip to. If None, only
      the valid-point count is recorded and no tensor is resized.
  """
  fields = standard_fields.InputDataFields
  point_count = tf.shape(inputs[fields.point_positions])[0]
  if pad_or_clip_size is None:
    inputs[fields.num_valid_points] = point_count
    return

  inputs[fields.num_valid_points] = tf.minimum(point_count, pad_or_clip_size)
  # Sorted so the graph ops are created in a deterministic order.
  for key in sorted(standard_fields.get_input_point_fields()):
    if key == fields.num_valid_points or key not in inputs:
      continue
    static_dims = inputs[key].get_shape().as_list()
    # Replace only the leading (point) dimension; keep the rest as-is.
    target_shape = [pad_or_clip_size] + static_dims[1:]
    inputs[key] = shape_utils.pad_or_clip_nd(
        tensor=inputs[key], output_shape=target_shape)
예제 #2
0
def apply_mask_to_input_point_tensors(inputs, valid_mask):
    """Applies a boolean mask to every point tensor in `inputs`.

    Args:
        inputs: A dictionary of input tensors. It is not modified.
        valid_mask: Mask passed as the second argument of `tf.boolean_mask`
            for each point field.

    Returns:
        A new dictionary mapping each point field present in `inputs` to
        `tf.boolean_mask(inputs[field], valid_mask)`. The
        `num_valid_points` field is omitted from the result.
    """
    count_field = standard_fields.InputDataFields.num_valid_points
    return {
        field: tf.boolean_mask(inputs[field], valid_mask)
        for field in standard_fields.get_input_point_fields()
        if field in inputs and field != count_field
    }
예제 #3
0
def get_batch_size_1_input_points(inputs, b):
    """Selects example `b` from every point tensor in `inputs`.

    Only fields returned by `standard_fields.get_input_point_fields()` are
    considered. Every other entry of `inputs` is ignored and does not appear
    in the result.

    Args:
        inputs: A dictionary of tf.Tensors with our input data.
        b: Example index in the batch.

    Returns:
        A new dictionary mapping each point field present in `inputs` to
        `inputs[field][b]`. With an integer `b` this indexing removes the
        leading (batch) axis rather than keeping it with size one.
    """
    point_fields = standard_fields.get_input_point_fields()
    return {key: inputs[key][b] for key in point_fields if key in inputs}
예제 #4
0
def _validate_view_tensor_ranks(view_tensors, expected_rank, description):
  """Raises ValueError if any tensor in a per-view dictionary has the wrong rank.

  Args:
    view_tensors: A dictionary of tensors (anything whose `.shape` supports
      `len`).
    expected_rank: The required number of dimensions.
    description: Word inserted into the error message, e.g. 'image'.

  Raises:
    ValueError: If a tensor's rank differs from `expected_rank`. Keys are
      checked in sorted order, so the first offending key in that order is
      reported.
  """
  for key in sorted(view_tensors):
    if len(view_tensors[key].shape) != expected_rank:
      raise ValueError('%s %s should have rank %d.' %
                       (key, description, expected_rank))


def split_inputs(inputs,
                 input_field_mapping_fn,
                 image_preprocess_fn_dic,
                 images_points_correspondence_fn):
  """Splits inputs into view image, view indices 2d, mesh, object and other inputs.

  If `images_points_correspondence_fn` is given, it is called on `inputs`. Any
  point properties it returns are written back into `inputs`, so the caller's
  dictionary is modified in place before the split. The returned view
  dictionaries are the objects returned by that function, not copies.

  Args:
    inputs: Input dictionary. May be modified in place (see above).
    input_field_mapping_fn: A function that maps the input fields to the
      fields expected by object detection pipeline. It is applied after the
      correspondence step, and its return value is what gets split. If None,
      no mapping is applied.
    image_preprocess_fn_dic: A dictionary of image preprocessing functions.
      NOTE(review): not used by this function; kept for interface
      compatibility.
    images_points_correspondence_fn: A function that returns image and points
      correspondences as a dictionary. Recognized keys are 'points_position',
      'points_intensity', 'points_elongation', 'points_normal',
      'points_color', 'view_images' and 'view_indices_2d'. If None, the
      returned view dictionaries are empty.

  Returns:
    view_image_inputs: A dictionary containing image inputs (each rank 4).
    view_indices_2d_inputs: A dictionary containing indices 2d inputs (each
      rank 3).
    mesh_inputs: Entries of the (mapped) inputs whose keys are point fields.
    object_inputs: Entries whose keys are object fields but not point fields.
    non_tensor_inputs: All remaining entries.

  Raises:
    ValueError: If a view image is not rank 4 or a view indices_2d tensor is
      not rank 3.
  """
  del image_preprocess_fn_dic  # Unused.
  view_image_inputs = {}
  view_indices_2d_inputs = {}

  # Acquire point / image correspondences.
  if images_points_correspondence_fn is not None:
    fn_outputs = images_points_correspondence_fn(inputs)
    # Correspondence output key -> InputDataFields attribute name. The
    # attribute is resolved only for keys actually present in fn_outputs.
    point_output_fields = (
        ('points_position', 'point_positions'),
        ('points_intensity', 'point_intensities'),
        ('points_elongation', 'point_elongations'),
        ('points_normal', 'point_normals'),
        ('points_color', 'point_colors'),
    )
    for output_key, field_attr in point_output_fields:
      if output_key in fn_outputs:
        field = getattr(standard_fields.InputDataFields, field_attr)
        inputs[field] = fn_outputs[output_key]
    if 'view_images' in fn_outputs:
      _validate_view_tensor_ranks(fn_outputs['view_images'], 4, 'image')
      view_image_inputs = fn_outputs['view_images']
    if 'view_indices_2d' in fn_outputs:
      _validate_view_tensor_ranks(
          fn_outputs['view_indices_2d'], 3, 'indices_2d')
      view_indices_2d_inputs = fn_outputs['view_indices_2d']

  if input_field_mapping_fn is not None:
    inputs = input_field_mapping_fn(inputs)

  # Sets give O(1) membership. Point fields win when a key is in both.
  point_fields = set(standard_fields.get_input_point_fields())
  object_fields = set(standard_fields.get_input_object_fields())
  mesh_inputs = {}
  object_inputs = {}
  non_tensor_inputs = {}
  for key, value in inputs.items():
    if key in point_fields:
      mesh_inputs[key] = value
    elif key in object_fields:
      object_inputs[key] = value
    else:
      non_tensor_inputs[key] = value

  for label, group in (('view image inputs', view_image_inputs),
                       ('view indices 2d inputs', view_indices_2d_inputs),
                       ('mesh inputs', mesh_inputs),
                       ('object inputs', object_inputs),
                       ('non_tensor_inputs', non_tensor_inputs)):
    logging.info(label)
    logging.info(group)
  return (view_image_inputs, view_indices_2d_inputs, mesh_inputs, object_inputs,
          non_tensor_inputs)