def _get_voxels_valid_inputs_outputs(inputs_1, outputs_1):
    """Restrict input and output voxel tensors to the valid voxels.

    The validity mask is derived from ``inputs_1`` via
    ``_get_voxels_valid_mask`` and then applied to both sides.

    Args:
      inputs_1: Input container handed to the mask helpers (presumably a dict
        of voxel tensors -- TODO confirm against callers).
      outputs_1: Output container handed to
        ``mask_utils.apply_mask_to_output_voxel_tensors``.

    Returns:
      A ``(masked_inputs, outputs_1)`` pair. ``masked_inputs`` is whatever
      ``apply_mask_to_input_voxel_tensors`` returns. ``outputs_1`` is the same
      object that was passed in.
    """
    voxel_mask = _get_voxels_valid_mask(inputs_1=inputs_1)
    masked_inputs = mask_utils.apply_mask_to_input_voxel_tensors(
        inputs=inputs_1, valid_mask=voxel_mask)
    # The result of the output-side call is not captured. This relies on the
    # helper updating ``outputs_1`` in place (assumed -- confirm in mask_utils).
    mask_utils.apply_mask_to_output_voxel_tensors(
        outputs=outputs_1, valid_mask=voxel_mask)
    return masked_inputs, outputs_1
def mask_valid_voxels(inputs, outputs):
    """Mask the voxels that are valid.

    If ``outputs`` already contains
    ``standard_fields.DetectionResultFields.objects_center``, it is returned
    untouched. Otherwise:

    1. A validity mask is built from ``inputs`` with
       ``mask_utils.num_voxels_mask``.
    2. The mask is applied to the output voxel tensors.
    3. Each voxel field listed in
       ``standard_fields.get_output_voxel_to_object_field_mapping()`` that is
       present in ``outputs`` is copied to its mapped object-level key.

    NOTE(review): a second ``mask_valid_voxels`` is defined later in this
    module and rebinds the name at import time, so this definition is
    shadowed. Confirm which one is intended.

    Args:
      inputs: Input container passed to ``mask_utils.num_voxels_mask``.
      outputs: Mutable mapping of output tensors; it is updated in place.

    Returns:
      ``outputs``, on every path.
    """
    if standard_fields.DetectionResultFields.objects_center in outputs:
        return outputs
    valid_mask = mask_utils.num_voxels_mask(inputs=inputs)
    # The helper's return value is not used. This assumes it masks ``outputs``
    # in place -- TODO confirm in mask_utils.
    mask_utils.apply_mask_to_output_voxel_tensors(
        outputs=outputs, valid_mask=valid_mask)
    voxel_to_object = standard_fields.get_output_voxel_to_object_field_mapping()
    for voxel_key, object_key in voxel_to_object.items():
        if voxel_key in outputs:
            outputs[object_key] = outputs[voxel_key]
    # Return on this path too, so callers get the same value as from the
    # early exit above instead of an implicit None.
    return outputs
def mask_valid_voxels(inputs, outputs):
    """Apply the voxel validity mask derived from ``inputs`` to ``outputs``.

    The mask comes from ``mask_utils.num_voxels_mask`` and is handed to
    ``mask_utils.apply_mask_to_output_voxel_tensors``. Nothing is returned
    (implicit ``None``). Any change to ``outputs`` is presumably made in place
    by that helper -- TODO confirm.

    NOTE(review): the original summary said "valid and in image view", but
    this body only applies ``num_voxels_mask``. Confirm whether that mask
    accounts for image view.

    Args:
      inputs: Input container passed to ``mask_utils.num_voxels_mask``.
      outputs: Output container passed to
        ``mask_utils.apply_mask_to_output_voxel_tensors``.
    """
    mask_utils.apply_mask_to_output_voxel_tensors(
        outputs=outputs,
        valid_mask=mask_utils.num_voxels_mask(inputs=inputs),
    )