def _pad_or_clip_point_properties(inputs, pad_or_clip_size):
  """Records the number of valid points and optionally pads/clips point fields.

  `inputs` is modified in place. The num_valid_points field is always set to
  the size of the leading dimension of the point positions tensor. When
  `pad_or_clip_size` is not None, that count is capped at `pad_or_clip_size`
  and every other point field present in `inputs` is passed through
  `shape_utils.pad_or_clip_nd` with a target shape of
  `[pad_or_clip_size] + <the tensor's static trailing dimensions>`.

  Args:
    inputs: A dictionary containing input tensors. It must contain the point
      positions field, which is read unconditionally.
    pad_or_clip_size: Number of target points to pad or clip to. If None, only
      the number of valid points is recorded; no padding or clipping is done.
  """
  num_valid_points_field = standard_fields.InputDataFields.num_valid_points
  inputs[num_valid_points_field] = tf.shape(
      inputs[standard_fields.InputDataFields.point_positions])[0]
  if pad_or_clip_size is None:
    return

  # After clipping, at most `pad_or_clip_size` points remain valid.
  inputs[num_valid_points_field] = tf.minimum(
      inputs[num_valid_points_field], pad_or_clip_size)
  for key in sorted(standard_fields.get_input_point_fields()):
    # num_valid_points is a scalar count, not a per-point tensor.
    if key == num_valid_points_field or key not in inputs:
      continue
    # Only the leading (point) dimension is changed; the remaining dimensions
    # are copied from the tensor's static shape (they may be None if unknown).
    static_shape = inputs[key].get_shape().as_list()
    padding_shape = [pad_or_clip_size] + static_shape[1:]
    inputs[key] = shape_utils.pad_or_clip_nd(
        tensor=inputs[key], output_shape=padding_shape)
def apply_mask_to_input_point_tensors(inputs, valid_mask):
  """Returns the point tensors of `inputs` filtered by `valid_mask`.

  Every point field present in `inputs`, except the num_valid_points field,
  is passed through `tf.boolean_mask` with `valid_mask`. Entries that are not
  point fields are left out of the result, and `inputs` itself is only read.

  Args:
    inputs: A dictionary of input tensors.
    valid_mask: Boolean mask handed to `tf.boolean_mask` for each point tensor.

  Returns:
    A new dictionary mapping each masked point field to its masked tensor.
  """
  return {
      field: tf.boolean_mask(inputs[field], valid_mask)
      for field in standard_fields.get_input_point_fields()
      if field in inputs and
      field != standard_fields.InputDataFields.num_valid_points
  }
def get_batch_size_1_input_points(inputs, b):
  """Selects example `b` from every point tensor in `inputs`.

  Only point fields are indexed and returned; all other entries of `inputs`
  are omitted from the result.

  Args:
    inputs: A dictionary of tf.Tensors with our input data.
    b: Example index in the batch.

  Returns:
    A new dictionary mapping each point field present in `inputs` to
    `inputs[field][b]`, i.e. each tensor indexed by `b` along its leading
    dimension (presumably the batch dimension).
  """
  point_fields = standard_fields.get_input_point_fields()
  return {field: inputs[field][b] for field in point_fields if field in inputs}
def split_inputs(inputs,
                 input_field_mapping_fn,
                 image_preprocess_fn_dic,
                 images_points_correspondence_fn):
  """Splits inputs into image, indices_2d, mesh, object and other inputs.

  If `images_points_correspondence_fn` is given, it is called on `inputs` and
  any of the point properties it returns ('points_position',
  'points_intensity', 'points_elongation', 'points_normal', 'points_color')
  are written into `inputs` under the corresponding standard field names.
  Note that this mutates the caller's `inputs` dictionary. Its 'view_images'
  and 'view_indices_2d' outputs, when present, are rank-checked and returned
  as the image and indices_2d inputs.

  Afterwards `input_field_mapping_fn` (if given) is applied, and every entry
  of the resulting dictionary is routed to exactly one of: mesh inputs (point
  fields), object inputs (object fields), or non-tensor inputs (everything
  else). Point fields take precedence over object fields.

  Args:
    inputs: Input dictionary.
    input_field_mapping_fn: A function that maps the input fields to the fields
      expected by object detection pipeline, or None.
    image_preprocess_fn_dic: A dictionary of image preprocessing functions.
      Currently unused by this function.
    images_points_correspondence_fn: A function that returns image and points
      correspondences, or None.

  Returns:
    view_image_inputs: A dictionary containing image inputs.
    view_indices_2d_inputs: A dictionary containing indices 2d inputs.
    mesh_inputs: A dictionary containing mesh inputs.
    object_inputs: A dictionary containing object inputs.
    non_tensor_inputs: Other inputs.

  Raises:
    ValueError: If a view image does not have rank 4, or a view indices_2d
      tensor does not have rank 3.
  """
  # Accepted for interface compatibility; nothing below uses it.
  del image_preprocess_fn_dic

  view_image_inputs = {}
  view_indices_2d_inputs = {}

  # Acquire point / image correspondences.
  if images_points_correspondence_fn is not None:
    fn_outputs = images_points_correspondence_fn(inputs)
    # (correspondence output key, standard input field) pairs, in the order in
    # which they are written into `inputs`.
    point_output_to_field = (
        ('points_position', standard_fields.InputDataFields.point_positions),
        ('points_intensity',
         standard_fields.InputDataFields.point_intensities),
        ('points_elongation',
         standard_fields.InputDataFields.point_elongations),
        ('points_normal', standard_fields.InputDataFields.point_normals),
        ('points_color', standard_fields.InputDataFields.point_colors),
    )
    for output_key, field in point_output_to_field:
      if output_key in fn_outputs:
        inputs[field] = fn_outputs[output_key]

    if 'view_images' in fn_outputs:
      view_images = fn_outputs['view_images']
      for key in sorted(view_images):
        if len(view_images[key].shape) != 4:
          raise ValueError('%s image should have rank 4.' % key)
      view_image_inputs = view_images

    if 'view_indices_2d' in fn_outputs:
      view_indices_2d = fn_outputs['view_indices_2d']
      for key in sorted(view_indices_2d):
        if len(view_indices_2d[key].shape) != 3:
          raise ValueError('%s indices_2d should have rank 3.' % key)
      view_indices_2d_inputs = view_indices_2d

  if input_field_mapping_fn is not None:
    inputs = input_field_mapping_fn(inputs)

  # Route each entry to one output dict. Since every key comes from `inputs`,
  # testing membership in the full field sets is equivalent to testing it
  # against the subset of fields present in `inputs`.
  point_fields = set(standard_fields.get_input_point_fields())
  object_fields = set(standard_fields.get_input_object_fields())
  mesh_inputs = {}
  object_inputs = {}
  non_tensor_inputs = {}
  for key, value in inputs.items():
    if key in point_fields:
      mesh_inputs[key] = value
    elif key in object_fields:
      object_inputs[key] = value
    else:
      non_tensor_inputs[key] = value

  logging.info('view image inputs')
  logging.info(view_image_inputs)
  logging.info('view indices 2d inputs')
  logging.info(view_indices_2d_inputs)
  logging.info('mesh inputs')
  logging.info(mesh_inputs)
  logging.info('object inputs')
  logging.info(object_inputs)
  logging.info('non_tensor_inputs')
  logging.info(non_tensor_inputs)
  return (view_image_inputs, view_indices_2d_inputs, mesh_inputs,
          object_inputs, non_tensor_inputs)