def apply_mask_to_input_voxel_tensors(inputs, valid_mask):
  """Applies mask to input voxel tensors.

  Every field listed by `standard_fields.get_input_voxel_fields()` that is
  present in `inputs` is passed through `tf.boolean_mask` with `valid_mask`.
  The `num_valid_voxels` field is skipped and does not appear in the output.
  NOTE(review): presumably it is skipped because it is a count rather than a
  per-voxel tensor, so masking it would not make sense. Confirm.

  Args:
    inputs: A dictionary of tf.Tensors with our input data.
    valid_mask: Boolean mask handed to `tf.boolean_mask` for each voxel field.

  Returns:
    A new dictionary that maps each masked voxel field to its masked tensor.
  """
  return {
      name: tf.boolean_mask(inputs[name], valid_mask)
      for name in standard_fields.get_input_voxel_fields()
      if name in inputs and
      name != standard_fields.InputDataFields.num_valid_voxels
  }
def get_batch_size_1_input_voxels(inputs, b):
  """Returns input dictionary containing tensors with batch size of 1.

  Only the voxel fields (those listed by
  `standard_fields.get_input_voxel_fields()`) are selected. Any other keys
  in `inputs` are not copied into the result.

  Args:
    inputs: A dictionary of tf.Tensors with our input data.
    b: Example index in the batch.

  Returns:
    A new dictionary that maps each voxel field present in `inputs` to
    `inputs[field][b]`, i.e. the b-th entry along the leading axis.
  """
  voxel_fields = standard_fields.get_input_voxel_fields()
  return {name: inputs[name][b] for name in voxel_fields if name in inputs}