def test_check_blob_exists(self):
    """check_blob_exists: True for present blobs, False for a missing one,
    and IndexError for a bare container URI with no blob name."""
    # blob addressed by its full URI
    print('PUBLIC_BLOB_URI')
    exists_by_uri = check_blob_exists(PUBLIC_BLOB_URI)
    self.assertTrue(exists_by_uri)

    # blob addressed as a container URI plus a separate blob name
    print('PUBLIC_CONTAINER_URI + PUBLIC_BLOB_NAME')
    exists_by_name = check_blob_exists(PUBLIC_CONTAINER_URI,
                                       blob_name=PUBLIC_BLOB_NAME)
    self.assertTrue(exists_by_name)

    # a container URI without a blob name is expected to raise IndexError
    print('PUBLIC_CONTAINER_URI')
    self.assertRaises(IndexError, check_blob_exists, PUBLIC_CONTAINER_URI)

    # URI pointing at a blob that is not there
    print('PUBLIC_INVALID_BLOB_URI')
    exists_invalid = check_blob_exists(PUBLIC_INVALID_BLOB_URI)
    self.assertFalse(exists_invalid)
def visualize_detector_output(detector_output_path: str,
                              out_dir: str,
                              images_dir: str,
                              is_azure: bool = False,
                              confidence: float = 0.8,
                              sample: int = -1,
                              output_image_width: int = 700,
                              random_seed: Optional[int] = None) -> List[str]:
    """Draw bounding boxes on images given the output of the detector.

    Args:
        detector_output_path: str, path to detector output json file
        out_dir: str, path to directory for saving annotated images
        images_dir: str, path to local images dir, or a SAS URL to an Azure Blob
            Storage container
        is_azure: bool, whether images_dir points to an Azure URL
        confidence: float, threshold above which annotations will be rendered
        sample: int, maximum number of images to annotate, -1 for all
            (any value <= 0 annotates every entry)
        output_image_width: int, width in pixels to resize images for display,
            set to -1 to use original image width
        random_seed: int, for deterministic image sampling when sample != -1;
            a private random.Random is used, so the global `random` module
            state is not modified

    Returns: list of str, paths to annotated images, in the order they were
        written. Entries with a "failure" field and images that cannot be
        found are skipped.

    Raises:
        AssertionError: if confidence is not in (0, 1), the detector output
            file does not exist, images_dir is not a directory (local mode),
            the json has no "images" field, or sample exceeds the number of
            entries. These checks use `assert`, so they are skipped under -O.
    """
    # arguments error checking
    assert 0 < confidence < 1, (
        f'Confidence threshold {confidence} is invalid, must be in (0, 1).')

    assert os.path.exists(detector_output_path), (
        f'Detector output file does not exist at {detector_output_path}.')

    if is_azure:
        # we don't import sas_blob_utils at the top of this file in order to
        # accommodate the MegaDetector Colab notebook which does not have
        # the azure-storage-blob package installed
        import sas_blob_utils
    else:
        assert os.path.isdir(images_dir)

    os.makedirs(out_dir, exist_ok=True)

    #%% Load detector output

    with open(detector_output_path, encoding='utf-8') as f:
        detector_output = json.load(f)
    assert 'images' in detector_output, (
        'Detector output file should be a json with an "images" field.')
    images = detector_output['images']

    if 'detection_categories' in detector_output:
        print('detection_categories provided')
        detector_label_map = detector_output['detection_categories']
    else:
        detector_label_map = DEFAULT_DETECTOR_LABEL_MAP

    num_images = len(images)
    print(f'Detector output file contains {num_images} entries.')

    if sample > 0:
        images = _sample_detector_entries(images, sample, random_seed)
        print(f'Sampled {len(images)} entries from the detector output file.')

    #%% Load images, annotate them and save

    print('Starting to annotate the images...')
    annotated_img_paths = []
    image_obj: Any  # str for local images, BytesIO for Azure images

    for entry in tqdm(images):
        image_id = entry['file']

        if 'failure' in entry:
            print(f'Skipping {image_id}, failure: "{entry["failure"]}"')
            continue

        if is_azure:
            blob_uri = sas_blob_utils.build_blob_uri(container_uri=images_dir,
                                                     blob_name=image_id)
            if not sas_blob_utils.check_blob_exists(blob_uri):
                container = sas_blob_utils.get_container_from_uri(images_dir)
                print(f'Image {image_id} not found in blob container '
                      f'{container}; skipped.')
                continue
            image_obj, _ = sas_blob_utils.download_blob_to_stream(blob_uri)
        else:
            image_obj = os.path.join(images_dir, image_id)
            if not os.path.exists(image_obj):
                print(f'Image {image_id} not found in images_dir; skipped.')
                continue

        try:
            # resize is for displaying them more quickly
            image = vis_utils.resize_image(vis_utils.open_image(image_obj),
                                           output_image_width)

            vis_utils.render_detection_bounding_boxes(
                entry['detections'],
                image,
                label_map=detector_label_map,
                confidence_threshold=confidence)

            # flatten path separators (and ':') so every annotated image is
            # written as a single file directly inside out_dir
            safe_name = image_id
            for char in ['/', '\\', ':']:
                safe_name = safe_name.replace(char, '~')
            annotated_img_path = os.path.join(out_dir, f'anno_{safe_name}')
            image.save(annotated_img_path)
        finally:
            # release the downloaded stream even if rendering or saving fails
            if is_azure:
                image_obj.close()  # BytesIO object

        annotated_img_paths.append(annotated_img_path)

    print(f'Rendered detection results on {len(annotated_img_paths)} images, '
          f'saved to {out_dir}.')

    return annotated_img_paths


def _sample_detector_entries(images: List[dict],
                             sample: int,
                             random_seed: Optional[int]) -> List[dict]:
    """Randomly pick `sample` detector entries and return them sorted by 'file'.

    When random_seed is given, entries are first sorted by 'file' so the
    selection does not depend on input order. The shuffle then uses a private
    random.Random(random_seed). This selects exactly what seeding the global
    RNG would, but leaves the global `random` state untouched.

    Raises:
        AssertionError: if sample exceeds the number of entries
    """
    num_images = len(images)
    assert num_images >= sample, (
        f'Sample size {sample} greater than number of entries '
        f'({num_images}) in detector result.')

    if random_seed is not None:
        candidates = sorted(images, key=lambda x: x['file'])
        random.Random(random_seed).shuffle(candidates)
    else:
        candidates = list(images)
        random.shuffle(candidates)
    return sorted(candidates[:sample], key=lambda x: x['file'])
Beispiel #3
0
 def check_local_then_azure(local_path: str, blob_url: str) -> bool:
     """Report whether a file exists locally or, failing that, as an Azure blob.

     The local filesystem is checked first; the blob URL is only queried when
     nothing exists at local_path.

     Args:
         local_path: str, path on the local filesystem
         blob_url: str, URL of the blob to check when the local file is absent

     Returns: True if local_path exists, otherwise the result of
         sas_blob_utils.check_blob_exists(blob_url)
     """
     if os.path.exists(local_path):
         return True
     return sas_blob_utils.check_blob_exists(blob_url)
def check_image_condition(img_path: str,
                          truncated_images_lock: threading.Lock,
                          account: Optional[str] = None,
                          container: Optional[str] = None,
                          sas_token: Optional[str] = None,
                          datasets_table: Optional[Mapping[str, Any]] = None
                          ) -> Tuple[str, str]:
    """Download one image blob and classify whether PIL can open it.

    Pass either account + container (+ optional sas_token), or only
    datasets_table. Mixing the two is rejected by assertions.

    Args:
        img_path: str, either <blob_name> if datasets_table is None, or
            <dataset>/<blob_name> if datasets_table is given
        truncated_images_lock: threading.Lock, held while the process-wide
            ImageFile.LOAD_TRUNCATED_IMAGES flag is set and the image is
            loaded, so callers sharing the lock do not interfere
        account: str, name of Azure Blob Storage account
        container: str, name of Azure Blob Storage container
        sas_token: str, optional SAS token (without leading '?') if the
            container is not publicly accessible
        datasets_table: dict, maps dataset name to dict of information; the
            keys read here are 'storage_account', 'container' and
            'container_sas_key'

    Returns: (img_file, status) tuple, where status is one of
        'nonexistant': blob does not exist in the container
        'non_image': img_file does not have valid file extension
        'good': image exists and is able to be opened without setting
            ImageFile.LOAD_TRUNCATED_IMAGES=True
        'truncated': image exists but can only be opened by setting
            ImageFile.LOAD_TRUNCATED_IMAGES=True
        'bad': image exists, but cannot be opened even when setting
            ImageFile.LOAD_TRUNCATED_IMAGES=True

    Raises:
        AssertionError: if explicit account/container/sas_token arguments are
            combined with datasets_table, or neither form is complete
    """
    if (account is None) or (container is None) or (datasets_table is not None):
        assert account is None
        assert container is None
        assert sas_token is None
        assert datasets_table is not None

        dataset, img_file = img_path.split('/', maxsplit=1)
        dataset_info = datasets_table[dataset]
        account = dataset_info['storage_account']
        container = dataset_info['container']
        sas_token = dataset_info['container_sas_key']
        # strip a leading '?' from the SAS token; an empty (or None) token
        # would otherwise raise IndexError/TypeError on sas_token[0]
        if sas_token and sas_token.startswith('?'):
            sas_token = sas_token[1:]
    else:
        img_file = img_path

    if not path_utils.is_image_file(img_file):
        return img_file, 'non_image'

    blob_url = sas_blob_utils.build_azure_storage_uri(
        account=account, container=container, sas_token=sas_token,
        blob=img_file)
    if not sas_blob_utils.check_blob_exists(blob_url):
        return img_file, 'nonexistant'

    stream, _ = sas_blob_utils.download_blob_to_stream(blob_url)
    try:
        stream.seek(0)
        try:
            with truncated_images_lock:
                ImageFile.LOAD_TRUNCATED_IMAGES = False
                with Image.open(stream) as img:
                    img.load()
            return img_file, 'good'
        except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
            # retry, this time letting PIL accept truncated image data
            try:
                stream.seek(0)
                with truncated_images_lock:
                    ImageFile.LOAD_TRUNCATED_IMAGES = True
                    with Image.open(stream) as img:
                        img.load()
                return img_file, 'truncated'
            except Exception as e:  # pylint: disable=broad-except
                exception_type = type(e).__name__
                tqdm.write(f'Unable to load {img_file}. {exception_type}: {e}.')
                return img_file, 'bad'
    finally:
        # release the downloaded stream regardless of the outcome
        stream.close()
Beispiel #5
0
def request_detections():
    """
    Checks that the input parameters to this endpoint are valid, starts a thread
    to launch the batch processing job, and return the job_id/request_id to the user.

    Returns {'request_id': <hex job id>} on success, otherwise the value of
    make_error(status_code, message) for the first failed check.
    """
    if not request.is_json:
        msg = 'Body needs to have a JSON mimetype (e.g., application/json).'
        return make_error(415, msg)

    try:
        post_body = request.get_json()
    except Exception as e:
        return make_error(415,
                          f'Error occurred reading POST request body: {e}.')

    # well-formed JSON may still be null, a list or a scalar; everything below
    # relies on post_body.get(), so require a JSON object up front
    if not isinstance(post_body, dict):
        return make_error(400, 'POST body must be a JSON object.')

    # NOTE(review): post_body can carry SAS URLs (input_container_sas,
    # images_requested_json_sas) whose tokens are logged verbatim here;
    # consider redacting them.
    app.logger.info(f'server, request_detections, post_body: {post_body}')

    # required params

    caller_id = post_body.get('caller', None)
    if caller_id is None or caller_id not in app_config.get_allowlist():
        msg = ('Parameter caller is not supplied or is not on our allowlist. '
               'Please email [email protected] to request access.')
        return make_error(401, msg)

    use_url = post_body.get('use_url', False)
    if use_url and isinstance(
            use_url,
            str):  # in case it is included but is intended to be False
        if use_url.lower() in ['false', 'f', 'no', 'n']:
            use_url = False

    input_container_sas = post_body.get('input_container_sas', None)
    if not input_container_sas and not use_url:
        msg = ('input_container_sas with read and list access is a required '
               'field when not using image URLs.')
        return make_error(400, msg)

    if input_container_sas is not None:
        if not sas_blob_utils.is_container_uri(input_container_sas):
            return make_error(
                400, 'input_container_sas provided is not for a container.')

        result = check_data_container_sas(input_container_sas)
        if result is not None:
            return make_error(result[0], result[1])

    images_requested_json_sas = post_body.get('images_requested_json_sas',
                                              None)
    if images_requested_json_sas is not None:
        try:
            exists = sas_blob_utils.check_blob_exists(
                images_requested_json_sas)
        except Exception:
            return make_error(400, 'images_requested_json_sas is not valid.')
        if not exists:
            return make_error(
                400,
                'images_requested_json_sas points to a non-existent file.')

    # if use_url, then images_requested_json_sas is required
    if use_url:
        if images_requested_json_sas is None:
            msg = 'images_requested_json_sas is required since use_url is true.'
            return make_error(400, msg)

    # optional params

    # check model_version is among the available model versions
    model_version = post_body.get('model_version', '')
    if model_version != '':
        model_version = str(model_version)  # in case user used an int
        if model_version not in api_config.MD_VERSIONS_TO_REL_PATH:  # TODO use AppConfig to store model version info
            return make_error(
                400, f'model_version {model_version} is not supported.')

    # check request_name has only allowed characters
    request_name = post_body.get('request_name', '')
    if request_name != '':
        # len()/set() below would raise on a non-string JSON value
        if not isinstance(request_name, str):
            return make_error(400, 'request_name must be a string.')
        if len(request_name) > 92:
            return make_error(400,
                              'request_name is longer than 92 characters.')
        allowed = set(string.ascii_letters + string.digits + '_' + '-')
        if not set(request_name) <= allowed:
            msg = ('request_name contains invalid characters (only letters, '
                   'digits, - and _ are allowed).')
            return make_error(400, msg)

    # optional params for telemetry collection (country, organization_name)
    # are not validated here; they are logged to the status table for now as
    # part of call_params=post_body below

    # All API instances / node pools share a quota on total number of active Jobs;
    # we cannot accept new Job submissions if we are at the quota
    try:
        num_active_jobs = batch_job_manager.get_num_active_jobs()
        if num_active_jobs >= api_config.MAX_BATCH_ACCOUNT_ACTIVE_JOBS:
            return make_error(503,
                              'Too many active jobs, please try again later')
    except Exception as e:
        return make_error(500, f'Error checking number of active jobs: {e}')

    try:
        job_id = uuid.uuid4().hex
        job_status_table.create_job_status(
            job_id=job_id,
            status=get_job_status('created',
                                  'Request received. Listing images next...'),
            call_params=post_body)
    except Exception as e:
        return make_error(500, f'Error creating a job status entry: {e}')

    try:
        # the batch job is created asynchronously; the caller polls using
        # the returned request_id
        thread = threading.Thread(target=create_batch_job,
                                  name=f'job_{job_id}',
                                  kwargs={
                                      'job_id': job_id,
                                      'body': post_body
                                  })
        thread.start()
    except Exception as e:
        return make_error(
            500,
            f'Error creating or starting the batch processing thread: {e}')

    return {'request_id': job_id}