def test_check_blob_exists(self):
    print('PUBLIC_BLOB_URI')
    self.assertTrue(check_blob_exists(PUBLIC_BLOB_URI))

    print('PUBLIC_CONTAINER_URI + PUBLIC_BLOB_NAME')
    self.assertTrue(
        check_blob_exists(PUBLIC_CONTAINER_URI, blob_name=PUBLIC_BLOB_NAME))

    print('PUBLIC_CONTAINER_URI')
    with self.assertRaises(IndexError):
        check_blob_exists(PUBLIC_CONTAINER_URI)

    print('PUBLIC_INVALID_BLOB_URI')
    self.assertFalse(check_blob_exists(PUBLIC_INVALID_BLOB_URI))
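# A minimal sketch of the check that the test above exercises, assuming the
# azure-storage-blob v12 SDK. This is not the project's implementation of
# check_blob_exists in sas_blob_utils, which also accepts a container URI plus
# a separate blob_name and (per the test) raises IndexError for a
# container-only URI.
from azure.storage.blob import BlobClient

def _blob_exists_via_sdk(blob_url: str) -> bool:
    # from_blob_url accepts a full blob URL, optionally carrying a SAS token
    # in its query string; exists() asks the service whether the blob exists.
    return BlobClient.from_blob_url(blob_url).exists()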
def visualize_detector_output(detector_output_path: str,
                              out_dir: str,
                              images_dir: str,
                              is_azure: bool = False,
                              confidence: float = 0.8,
                              sample: int = -1,
                              output_image_width: int = 700,
                              random_seed: Optional[int] = None) -> List[str]:
    """Draw bounding boxes on images given the output of the detector.

    Args:
        detector_output_path: str, path to detector output json file
        out_dir: str, path to directory for saving annotated images
        images_dir: str, path to local images dir, or a SAS URL to an Azure
            Blob Storage container
        is_azure: bool, whether images_dir points to an Azure URL
        confidence: float, threshold above which annotations will be rendered
        sample: int, maximum number of images to annotate, -1 for all
        output_image_width: int, width in pixels to resize images for display,
            set to -1 to use original image width
        random_seed: int, for deterministic image sampling when sample != -1

    Returns: list of str, paths to annotated images
    """
    # arguments error checking
    assert confidence > 0 and confidence < 1, (
        f'Confidence threshold {confidence} is invalid, must be in (0, 1).')
    assert os.path.exists(detector_output_path), (
        f'Detector output file does not exist at {detector_output_path}.')

    if is_azure:
        # we don't import sas_blob_utils at the top of this file in order to
        # accommodate the MegaDetector Colab notebook which does not have
        # the azure-storage-blob package installed
        import sas_blob_utils
    else:
        assert os.path.isdir(images_dir)

    os.makedirs(out_dir, exist_ok=True)

    #%% Load detector output

    with open(detector_output_path) as f:
        detector_output = json.load(f)
    assert 'images' in detector_output, (
        'Detector output file should be a json with an "images" field.')
    images = detector_output['images']

    detector_label_map = DEFAULT_DETECTOR_LABEL_MAP
    if 'detection_categories' in detector_output:
        print('detection_categories provided')
        detector_label_map = detector_output['detection_categories']

    num_images = len(images)
    print(f'Detector output file contains {num_images} entries.')

    if sample > 0:
        assert num_images >= sample, (
            f'Sample size {sample} greater than number of entries '
            f'({num_images}) in detector result.')

        if random_seed is not None:
            images = sorted(images, key=lambda x: x['file'])
            random.seed(random_seed)

        random.shuffle(images)
        images = sorted(images[:sample], key=lambda x: x['file'])
        print(f'Sampled {len(images)} entries from the detector output file.')

    #%% Load images, annotate them and save

    print('Starting to annotate the images...')
    num_saved = 0
    annotated_img_paths = []
    image_obj: Any  # str for local images, BytesIO for Azure images

    for entry in tqdm(images):
        image_id = entry['file']

        if 'failure' in entry:
            print(f'Skipping {image_id}, failure: "{entry["failure"]}"')
            continue

        # max_conf = entry['max_detection_conf']

        if is_azure:
            blob_uri = sas_blob_utils.build_blob_uri(
                container_uri=images_dir, blob_name=image_id)
            if not sas_blob_utils.check_blob_exists(blob_uri):
                container = sas_blob_utils.get_container_from_uri(images_dir)
                print(f'Image {image_id} not found in blob container '
                      f'{container}; skipped.')
                continue
            image_obj, _ = sas_blob_utils.download_blob_to_stream(blob_uri)
        else:
            image_obj = os.path.join(images_dir, image_id)
            if not os.path.exists(image_obj):
                print(f'Image {image_id} not found in images_dir; skipped.')
                continue

        # resize is for displaying them more quickly
        image = vis_utils.resize_image(vis_utils.open_image(image_obj),
                                       output_image_width)

        vis_utils.render_detection_bounding_boxes(
            entry['detections'], image,
            label_map=detector_label_map,
            confidence_threshold=confidence)

        for char in ['/', '\\', ':']:
            image_id = image_id.replace(char, '~')
        annotated_img_path = os.path.join(out_dir, f'anno_{image_id}')
        annotated_img_paths.append(annotated_img_path)
        image.save(annotated_img_path)
        num_saved += 1

        if is_azure:
            image_obj.close()  # BytesIO object

    print(f'Rendered detection results on {num_saved} images, '
          f'saved to {out_dir}.')

    return annotated_img_paths
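# Hedged usage sketch of the function above: the file and directory names are
# hypothetical placeholders, not paths that ship with this repo.
def _visualize_example() -> None:
    annotated = visualize_detector_output(
        detector_output_path='detections.json',
        out_dir='annotated_images',
        images_dir='camera_trap_images',
        confidence=0.6,
        sample=50,         # annotate a random subset of 50 entries
        random_seed=0)     # make that subset reproducible
    print(f'Wrote {len(annotated)} annotated images.')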
def check_local_then_azure(local_path: str, blob_url: str) -> bool:
    return (os.path.exists(local_path)
            or sas_blob_utils.check_blob_exists(blob_url))
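# Note: the short-circuiting `or` above means the network round trip to Azure
# only happens when the local copy is missing.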
def check_image_condition(img_path: str,
                          truncated_images_lock: threading.Lock,
                          account: Optional[str] = None,
                          container: Optional[str] = None,
                          sas_token: Optional[str] = None,
                          datasets_table: Optional[Mapping[str, Any]] = None
                          ) -> Tuple[str, str]:
    """
    Args:
        img_path: str, either <blob_name> if datasets_table is None, or
            <dataset>/<blob_name> if datasets_table is given
        truncated_images_lock: threading.Lock, guards changes to the global
            ImageFile.LOAD_TRUNCATED_IMAGES flag
        account: str, name of Azure Blob Storage account
        container: str, name of Azure Blob Storage container
        sas_token: str, optional SAS token (without leading '?') if the
            container is not publicly accessible
        datasets_table: dict, maps dataset name to dict of information

    Returns: (img_file, status) tuple, where status is one of
        'nonexistant': blob does not exist in the container
        'non_image': img_file does not have valid file extension
        'good': image exists and is able to be opened without setting
            ImageFile.LOAD_TRUNCATED_IMAGES=True
        'truncated': image exists but can only be opened by setting
            ImageFile.LOAD_TRUNCATED_IMAGES=True
        'bad': image exists, but cannot be opened even when setting
            ImageFile.LOAD_TRUNCATED_IMAGES=True
    """
    if (account is None) or (container is None) or (datasets_table is not None):
        assert account is None
        assert container is None
        assert sas_token is None
        assert datasets_table is not None

        dataset, img_file = img_path.split('/', maxsplit=1)
        account = datasets_table[dataset]['storage_account']
        container = datasets_table[dataset]['container']
        sas_token = datasets_table[dataset]['container_sas_key']
        if sas_token[0] == '?':  # strip leading '?' from SAS token
            sas_token = sas_token[1:]
    else:
        img_file = img_path

    if not path_utils.is_image_file(img_file):
        return img_file, 'non_image'

    blob_url = sas_blob_utils.build_azure_storage_uri(
        account=account, container=container, sas_token=sas_token,
        blob=img_file)
    blob_exists = sas_blob_utils.check_blob_exists(blob_url)
    if not blob_exists:
        return img_file, 'nonexistant'

    stream, _ = sas_blob_utils.download_blob_to_stream(blob_url)
    stream.seek(0)
    try:
        with truncated_images_lock:
            ImageFile.LOAD_TRUNCATED_IMAGES = False
            with Image.open(stream) as img:
                img.load()
        return img_file, 'good'
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        try:
            stream.seek(0)
            with truncated_images_lock:
                ImageFile.LOAD_TRUNCATED_IMAGES = True
                with Image.open(stream) as img:
                    img.load()
            return img_file, 'truncated'
        except Exception as e:  # pylint: disable=broad-except
            exception_type = type(e).__name__
            tqdm.write(f'Unable to load {img_file}. {exception_type}: {e}.')
            return img_file, 'bad'
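# Hedged sketch of driving check_image_condition from a thread pool; the
# helper name and the 8-thread default are assumptions, and datasets_table is
# whatever mapping the caller loaded elsewhere. A single shared lock
# serializes the global ImageFile.LOAD_TRUNCATED_IMAGES toggle across workers.
from concurrent import futures
from typing import Any, Dict, List, Mapping

def _check_images_concurrently(img_paths: List[str],
                               datasets_table: Mapping[str, Any],
                               num_threads: int = 8) -> Dict[str, str]:
    lock = threading.Lock()
    statuses: Dict[str, str] = {}
    with futures.ThreadPoolExecutor(max_workers=num_threads) as pool:
        jobs = [pool.submit(check_image_condition, p, lock,
                            datasets_table=datasets_table)
                for p in img_paths]
        for job in futures.as_completed(jobs):
            img_file, status = job.result()
            statuses[img_file] = status
    return statuses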
def request_detections():
    """
    Checks that the input parameters to this endpoint are valid, starts a
    thread to launch the batch processing job, and returns the
    job_id/request_id to the user.
    """
    if not request.is_json:
        msg = 'Body needs to have a JSON mimetype (e.g., application/json).'
        return make_error(415, msg)

    try:
        post_body = request.get_json()
    except Exception as e:
        return make_error(415, f'Error occurred reading POST request body: {e}.')

    app.logger.info(f'server, request_detections, post_body: {post_body}')

    # required params

    caller_id = post_body.get('caller', None)
    if caller_id is None or caller_id not in app_config.get_allowlist():
        msg = ('Parameter caller is not supplied or is not on our allowlist. '
               'Please email [email protected] to request access.')
        return make_error(401, msg)

    use_url = post_body.get('use_url', False)
    if use_url and isinstance(use_url, str):
        # in case it is included but is intended to be False
        if use_url.lower() in ['false', 'f', 'no', 'n']:
            use_url = False

    input_container_sas = post_body.get('input_container_sas', None)
    if not input_container_sas and not use_url:
        msg = ('input_container_sas with read and list access is a required '
               'field when not using image URLs.')
        return make_error(400, msg)

    if input_container_sas is not None:
        if not sas_blob_utils.is_container_uri(input_container_sas):
            return make_error(
                400, 'input_container_sas provided is not for a container.')

        result = check_data_container_sas(input_container_sas)
        if result is not None:
            return make_error(result[0], result[1])

    images_requested_json_sas = post_body.get('images_requested_json_sas', None)
    if images_requested_json_sas is not None:
        try:
            exists = sas_blob_utils.check_blob_exists(images_requested_json_sas)
        except Exception as e:
            return make_error(400, 'images_requested_json_sas is not valid.')
        if not exists:
            return make_error(
                400, 'images_requested_json_sas points to a non-existent file.')

    # if use_url, then images_requested_json_sas is required
    if use_url:
        if images_requested_json_sas is None:
            msg = 'images_requested_json_sas is required since use_url is true.'
            return make_error(400, msg)

    # optional params

    # check model_version is among the available model versions
    model_version = post_body.get('model_version', '')
    if model_version != '':
        model_version = str(model_version)  # in case user used an int
        if model_version not in api_config.MD_VERSIONS_TO_REL_PATH:
            # TODO use AppConfig to store model version info
            return make_error(
                400, f'model_version {model_version} is not supported.')

    # check request_name has only allowed characters
    request_name = post_body.get('request_name', '')
    if request_name != '':
        if len(request_name) > 92:
            return make_error(400, 'request_name is longer than 92 characters.')
        allowed = set(string.ascii_letters + string.digits + '_' + '-')
        if not set(request_name) <= allowed:
            msg = ('request_name contains invalid characters (only letters, '
                   'digits, - and _ are allowed).')
            return make_error(400, msg)

    # optional params for telemetry collection - logged to status table for
    # now as part of call_params
    country = post_body.get('country', None)
    organization_name = post_body.get('organization_name', None)

    # All API instances / node pools share a quota on total number of active
    # Jobs; we cannot accept new Job submissions if we are at the quota
    try:
        num_active_jobs = batch_job_manager.get_num_active_jobs()
        if num_active_jobs >= api_config.MAX_BATCH_ACCOUNT_ACTIVE_JOBS:
            return make_error(503, 'Too many active jobs, please try again later')
    except Exception as e:
        return make_error(500, f'Error checking number of active jobs: {e}')

    try:
        job_id = uuid.uuid4().hex
        job_status_table.create_job_status(
            job_id=job_id,
            status=get_job_status('created',
                                  'Request received. Listing images next...'),
            call_params=post_body)
    except Exception as e:
        return make_error(500, f'Error creating a job status entry: {e}')

    try:
        thread = threading.Thread(
            target=create_batch_job,
            name=f'job_{job_id}',
            kwargs={'job_id': job_id, 'body': post_body})
        thread.start()
    except Exception as e:
        return make_error(
            500, f'Error creating or starting the batch processing thread: {e}')

    return {'request_id': job_id}
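# Hedged client-side sketch of calling the endpoint above. The service URL,
# route path, model version, and field values are placeholders/assumptions,
# not values defined by this module.
import requests

def _submit_detection_request(api_url: str, caller_id: str,
                              input_container_sas: str) -> str:
    body = {
        'caller': caller_id,                         # must be on the allowlist
        'input_container_sas': input_container_sas,  # needs read + list access
        'model_version': '4.1',                      # must be a supported version
        'request_name': 'example_run',               # letters, digits, - and _ only
        'country': 'USA',
        'organization_name': 'Example Org',
    }
    resp = requests.post(f'{api_url}/request_detections', json=body)
    resp.raise_for_status()
    return resp.json()['request_id']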