# Example 1
def get_image_sas_uris(img_paths: Iterable[str]) -> List[str]:
    """Converts a image paths to Azure Blob Storage blob URIs with SAS tokens.

    Args:
        img_paths: list of str, <dataset-name>/<image-filename>

    Returns:
        image_sas_uris: list of str, image blob URIs with SAS tokens, ready to
            pass to the batch detection API
    """
    # we need the datasets table for getting SAS keys
    datasets_table = megadb_utils.MegadbUtils().get_datasets_table()

    image_sas_uris = []
    for img_path in img_paths:
        dataset, img_file = img_path.split('/', maxsplit=1)

        # strip leading '?' from SAS token
        sas_token = datasets_table[dataset]['container_sas_key']
        if sas_token[0] == '?':
            sas_token = sas_token[1:]

        image_sas_uri = sas_blob_utils.build_azure_storage_uri(
            account=datasets_table[dataset]['storage_account'],
            container=datasets_table[dataset]['container'],
            blob=img_file,
            sas_token=sas_token)
        image_sas_uris.append(image_sas_uri)
    return image_sas_uris
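
# Usage sketch (hedged): the dataset name and file paths below are hypothetical
# placeholders, and MegadbUtils() must be able to reach MegaDB (credentials
# configured) for the datasets table lookup to succeed.
#
#     image_sas_uris = get_image_sas_uris([
#         'example_dataset/images/0001.jpg',
#         'example_dataset/images/0002.jpg',
#     ])
#     # each entry looks like:
#     # https://<account>.blob.core.windows.net/<container>/images/0001.jpg?<sas>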
# Example 2
def get_output_json(
        label_to_inclusions: Dict[str, Set[Tuple[str, str]]],
        mislabeled_images_dir: Optional[str] = None
) -> Dict[str, Dict[str, Any]]:
    """Queries MegaDB to get image paths matching dataset_labels.

    Args:
        label_to_inclusions: dict, maps label name to set of
            (dataset, dataset_label) tuples, output of validate_json()
        mislabeled_images_dir: optional str, path to directory of CSVs with
            known mislabeled images

    Returns: dict, maps each image path <dataset>/<img_file> (keys sorted) to a
        dict of properties
        - 'dataset': str, name of dataset that image is from
        - 'location': str or int, optional
        - 'class': str, class label from the dataset
        - 'label': list of str, assigned output label
        - 'bbox': list of dicts, optional
    """
    # because MegaDB is organized by dataset, we do the same
    # ds_to_labels = {
    #     'dataset_name': {
    #         'dataset_label': [output_label1, output_label2]
    #     }
    # }
    ds_to_labels: Dict[str, Dict[str, List[str]]] = {}
    for output_label, ds_dslabels_set in label_to_inclusions.items():
        for (ds, ds_label) in ds_dslabels_set:
            if ds not in ds_to_labels:
                ds_to_labels[ds] = {}
            if ds_label not in ds_to_labels[ds]:
                ds_to_labels[ds][ds_label] = []
            ds_to_labels[ds][ds_label].append(output_label)

    # we need the datasets table for getting full image paths
    megadb = megadb_utils.MegadbUtils()
    datasets_table = megadb.get_datasets_table()

    # The line
    #    [img.class[0], seq.class[0]][0] as class
    # selects the image-level class label if available; otherwise it selects the
    # sequence-level class label, i.e., the image class label takes priority
    # over the sequence class label. This assumes the following conditions,
    # enforced by the WHERE clause:
    # - at least one of the image or sequence class labels is given
    # - the image and sequence class labels are arrays of length at most 1
    #
    # In Azure Cosmos DB, if a field is not defined, then it is simply excluded
    # from the result. For example, on the following JSON object,
    #     {
    #         "dataset": "camera_traps",
    #         "seq_id": "1234",
    #         "location": "A1",
    #         "images": [{"file": "abcd.jpeg"}],
    #         "class": ["deer"],
    #     }
    # the array [img.class[0], seq.class[0]] just gives ['deer'] because
    # img.class is undefined and therefore excluded.
    query = '''
    SELECT
        seq.dataset,
        seq.location,
        img.file,
        [img.class[0], seq.class[0]][0] as class,
        img.bbox
    FROM sequences seq JOIN img IN seq.images
    WHERE (ARRAY_LENGTH(img.class) = 1
           AND ARRAY_CONTAINS(@dataset_labels, img.class[0])
        )
        OR (ARRAY_LENGTH(seq.class) = 1
            AND ARRAY_CONTAINS(@dataset_labels, seq.class[0])
            AND (NOT IS_DEFINED(img.class))
        )
    '''

    output_json = {}  # maps full image path to json object
    for ds in tqdm(sorted(ds_to_labels.keys())):  # sort for determinism
        mislabeled_images: Mapping[str, Any] = {}
        if mislabeled_images_dir is not None:
            csv_path = os.path.join(mislabeled_images_dir, f'{ds}.csv')
            if os.path.exists(csv_path):
                mislabeled_images = pd.read_csv(csv_path,
                                                index_col='file',
                                                squeeze=True)

        ds_labels = sorted(ds_to_labels[ds].keys())
        tqdm.write(f'Querying dataset "{ds}" for dataset labels: {ds_labels}')

        start = datetime.now()
        parameters = [dict(name='@dataset_labels', value=ds_labels)]
        results = megadb.query_sequences_table(query,
                                               partition_key=ds,
                                               parameters=parameters)
        elapsed = (datetime.now() - start).total_seconds()
        tqdm.write(f'- query took {elapsed:.0f}s, found {len(results)} images')

        # if the dataset has no path prefix, use the empty string '', because
        #     os.path.join('', x) == x
        path_prefix = datasets_table[ds].get('path_prefix', '')
        count_corrected = 0
        count_removed = 0
        for result in results:
            # result keys
            # - already has: ['dataset', 'location', 'file', 'class', 'bbox']
            # - add ['label'], remove ['file']
            img_file = os.path.join(path_prefix, result['file'])

            # if img is mislabeled, but we don't know the correct class, skip it
            # otherwise, update the img with the correct class, but skip the
            #   img if the correct class is not one we queried for
            if img_file in mislabeled_images:
                new_class = mislabeled_images[img_file]
                if pd.isna(new_class) or new_class not in ds_to_labels[ds]:
                    count_removed += 1
                    continue

                count_corrected += 1
                result['class'] = new_class

            img_path = os.path.join(ds, img_file)
            del result['file']
            ds_label = result['class']
            result['label'] = ds_to_labels[ds][ds_label]
            output_json[img_path] = result

        tqdm.write(f'- Removed {count_removed} mislabeled images.')
        tqdm.write(f'- Corrected labels for {count_corrected} images.')

    # sort keys for determinism
    output_json = {k: output_json[k] for k in sorted(output_json.keys())}
    return output_json
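
# Output sketch (hedged), based on the docstring above; the dataset, location,
# and file names are hypothetical placeholders:
#
#     {
#         'example_dataset/images/0001.jpg': {
#             'dataset': 'example_dataset',
#             'location': 'A1',        # optional
#             'class': 'deer',         # class label from the dataset
#             'label': ['deer'],       # assigned output label(s)
#             'bbox': [...],           # optional, ground-truth bounding boxes
#         },
#         ...
#     }
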
def download_and_crop(
        queried_images_json: Mapping[str, Mapping[str, Any]],
        detection_cache: Mapping[str, Mapping[str, Mapping[str, Any]]],
        detection_categories: Mapping[str, str],
        detector_version: str,
        cropped_images_dir: str,
        confidence_threshold: float,
        save_full_images: bool,
        square_crops: bool,
        check_crops_valid: bool,
        images_dir: Optional[str] = None,
        threads: int = 1,
        images_missing_detections: Optional[Iterable[str]] = None
        ) -> Tuple[List[str], int, int]:
    """
    Saves each crop to a file named after the original image, with an
    additional suffix appended, starting with 3 underscores:
    - if image has ground truth bboxes: "___cropXX.jpg", where "XX" indicates
        the bounding box index
    - if image has bboxes from MegaDetector: "___cropXX_mdvY.Y.jpg", where
        "Y.Y" indicates the MegaDetector version
    See module docstring for more info and examples.

    Note: this function is very similar to the "download_and_crop()" function in
        crop_detections.py. The main difference is that this function uses
        MegaDB to look up Azure Storage container information for images based
        on the dataset, whereas the crop_detections.py version has no concept
        of a "dataset" or of ground-truth bounding boxes from MegaDB.

    Args:
        queried_images_json: dict, represents JSON output of json_validator.py,
            all images in queried_images_json are assumed to have either ground
            truth or cached detected bounding boxes unless
            images_missing_detections is given
        detection_cache: dict, dataset_name => {img_path => detection_dict}
        detection_categories: dict, maps detection category ID to category name
        detector_version: str, detector version string, e.g., '4.1'
        cropped_images_dir: str, path to folder where cropped images are saved
        confidence_threshold: float, only crop bounding boxes above this value
        save_full_images: bool, whether to save downloaded images to images_dir,
            images_dir must be given and must exist if save_full_images=True
        square_crops: bool, whether to crop bounding boxes as squares
        check_crops_valid: bool, whether to load each crop to ensure the file is
            valid (i.e., not truncated)
        images_dir: optional str, path to folder where full images are saved
        threads: int, number of threads to use for downloading images
        images_missing_detections: optional list of str, image files to skip
            because they have no ground truth or cached detected bounding boxes

    Returns: 3-tuple of
        - list of str, images with bounding boxes that failed to download or
            crop properly
        - int, number of images downloaded
        - int, number of new crops saved
    """
    # error checking before we download and crop any images
    valid_img_paths = set(queried_images_json.keys())
    if images_missing_detections is not None:
        valid_img_paths -= set(images_missing_detections)
    for img_path in valid_img_paths:
        info_dict = queried_images_json[img_path]
        ds, img_file = img_path.split('/', maxsplit=1)
        assert ds == info_dict['dataset']

        if 'bbox' in info_dict:  # ground-truth bounding boxes
            pass
        elif img_file in detection_cache[ds]:  # detected bounding boxes
            bbox_dicts = detection_cache[ds][img_file]['detections']
            assert all('conf' in bbox_dict for bbox_dict in bbox_dicts)
            # convert from category ID to category name
            for d in bbox_dicts:
                d['category'] = detection_categories[d['category']]
        else:
            raise ValueError(f'{img_path} has no ground truth bounding boxes '
                             'and was not found in the detection cache. Please '
                             'include it in images_missing_detections.')

    # we need the datasets table for getting SAS keys
    datasets_table = megadb_utils.MegadbUtils().get_datasets_table()
    container_clients = {}  # dataset name => ContainerClient

    pool = futures.ThreadPoolExecutor(max_workers=threads)
    future_to_img_path = {}
    images_failed_download = []

    print(f'Getting bbox info for {len(valid_img_paths)} images...')
    for img_path in tqdm(sorted(valid_img_paths)):
        # we already did all error checking above, so we don't do any here
        info_dict = queried_images_json[img_path]
        ds, img_file = img_path.split('/', maxsplit=1)

        # get ContainerClient
        if ds not in container_clients:
            sas_token = datasets_table[ds]['container_sas_key']
            if sas_token[0] == '?':
                sas_token = sas_token[1:]
            url = sas_blob_utils.build_azure_storage_uri(
                account=datasets_table[ds]['storage_account'],
                container=datasets_table[ds]['container'],
                sas_token=sas_token)
            container_clients[ds] = ContainerClient.from_container_url(url)
        container_client = container_clients[ds]

        # get bounding boxes
        # we must include the dataset <ds> in <crop_path_template> because
        #    '{img_path}' actually gets populated with <img_file> in
        #    load_and_crop()
        is_ground_truth = ('bbox' in info_dict)
        if is_ground_truth:  # ground-truth bounding boxes
            bbox_dicts = info_dict['bbox']
            crop_path_template = os.path.join(
                cropped_images_dir, ds, '{img_path}___crop{n:>02d}.jpg')
        else:  # detected bounding boxes
            bbox_dicts = detection_cache[ds][img_file]['detections']
            crop_path_template = os.path.join(
                cropped_images_dir, ds,
                '{img_path}___crop{n:>02d}_' + f'mdv{detector_version}.jpg')

        ds_dir = None if images_dir is None else os.path.join(images_dir, ds)

        # get the image, either from disk or from Blob Storage
        future = pool.submit(
            load_and_crop, img_file, ds_dir, container_client, bbox_dicts,
            confidence_threshold, crop_path_template, save_full_images,
            square_crops, check_crops_valid)
        future_to_img_path[future] = img_path

    total = len(future_to_img_path)
    total_downloads = 0
    total_new_crops = 0
    print(f'Reading/downloading {total} images and cropping...')
    for future in tqdm(futures.as_completed(future_to_img_path), total=total):
        img_path = future_to_img_path[future]
        try:
            did_download, num_new_crops = future.result()
            total_downloads += did_download
            total_new_crops += num_new_crops
        except Exception as e:  # pylint: disable=broad-except
            exception_type = type(e).__name__
            tqdm.write(f'{img_path} - generated {exception_type}: {e}')
            images_failed_download.append(img_path)

    pool.shutdown()
    for container_client in container_clients.values():
        container_client.close()

    print(f'Downloaded {total_downloads} images.')
    print(f'Made {total_new_crops} new crops.')
    return images_failed_download, total_downloads, total_new_crops
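
# Usage sketch (hedged): all argument values below are hypothetical
# placeholders; detection_cache maps dataset name => {img_path =>
# detection_dict}, and detection_categories maps detection category IDs to
# names, as described in the docstring above.
#
#     images_failed, n_downloads, n_crops = download_and_crop(
#         queried_images_json=queried_images_json,
#         detection_cache=detection_cache,
#         detection_categories={'1': 'animal', '2': 'person', '3': 'vehicle'},
#         detector_version='4.1',
#         cropped_images_dir='run_idfg2/crops',
#         confidence_threshold=0.8,
#         save_full_images=False,
#         square_crops=True,
#         check_crops_valid=True,
#         threads=50)
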
import json
import os

from tqdm import tqdm

from data_management.megadb import megadb_utils
import sas_blob_utils

images_dir = ''
queried_images_json_path = 'run_idfg2/queried_images.json'
output_dir = 'run_idfg2/'

with open(queried_images_json_path, 'r') as f:
    js = json.load(f)

datasets_table = megadb_utils.MegadbUtils().get_datasets_table()

output_files = {}

pbar = tqdm(js.items())
for img_path, img_info in pbar:
    save_path = os.path.join(images_dir, img_path)
    if os.path.exists(save_path):
        continue

    ds, img_file = img_path.split('/', maxsplit=1)
    if ds not in output_files:
        output_path = os.path.join(output_dir, f'{ds}_images.txt')
        output_files[ds] = open(output_path, 'w')

        dataset_info = datasets_table[ds]
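
# Hedged sketch (an assumption, not code from the script above) of how this
# loop might continue: write each missing image's blob URL to the per-dataset
# list file, mirroring get_image_sas_uris, then close the files after the loop.
#
#     info = datasets_table[ds]
#     url = sas_blob_utils.build_azure_storage_uri(
#         account=info['storage_account'],
#         container=info['container'],
#         blob=img_file,
#         sas_token=info['container_sas_key'].lstrip('?'))
#     output_files[ds].write(url + '\n')
#
# ...and, after the loop finishes:
#
#     for f in output_files.values():
#         f.close()
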
def submit_batch_detection_api(images_to_detect: Iterable[str],
                               task_lists_dir: str,
                               detector_version: str,
                               account: str,
                               container: str,
                               sas_token: str,
                               caller: str,
                               batch_detection_api_url: str,
                               resume_file_path: str
                               ) -> Dict[str, List[Task]]:
    """
    Args:
        images_to_detect: list of str, image paths with the format
            <dataset-name>/<image-filename>
        task_lists_dir: str, path to local directory for saving JSON files
            each containing a list of image URLs corresponding to an API task
        detector_version: str, MegaDetector version string, e.g., '4.1',
            see {batch_detection_api_url}/supported_model_versions
        account: str, Azure Storage account name
        container: str, Azure Blob Storage container name, where the task lists
            will be uploaded
        sas_token: str, SAS token with write permissions for the container
        caller: str, allow-listed caller
        batch_detection_api_url: str, URL to batch detection API
        resume_file_path: str, path to save resume file

    Returns: dict, maps str dataset name to list of Task objects
    """
    filtered_images_to_detect = [
        x for x in images_to_detect if path_utils.is_image_file(x)]
    not_images = set(images_to_detect) - set(filtered_images_to_detect)
    if len(not_images) == 0:
        print('Good! All image files have valid file extensions.')
    else:
        print(f'Skipping {len(not_images)} files with non-image extensions:')
        pprint.pprint(sorted(not_images))
    images_to_detect = filtered_images_to_detect

    datasets_table = megadb_utils.MegadbUtils().get_datasets_table()

    images_by_dataset = split_images_list_by_dataset(images_to_detect)
    tasks_by_dataset = {}
    for dataset, image_paths in images_by_dataset.items():
        # get SAS URL for images container
        images_sas_token = datasets_table[dataset]['container_sas_key']
        if images_sas_token[0] == '?':
            images_sas_token = images_sas_token[1:]
        images_container_url = sas_blob_utils.build_azure_storage_uri(
            account=datasets_table[dataset]['storage_account'],
            container=datasets_table[dataset]['container'],
            sas_token=images_sas_token)

        # strip image paths of dataset name
        image_blob_names = [path[path.find('/') + 1:] for path in image_paths]

        tasks_by_dataset[dataset] = submit_batch_detection_api_by_dataset(
            dataset=dataset,
            image_blob_names=image_blob_names,
            images_container_url=images_container_url,
            task_lists_dir=task_lists_dir,
            detector_version=detector_version,
            account=account, container=container, sas_token=sas_token,
            caller=caller, batch_detection_api_url=batch_detection_api_url)

    # save list of dataset names and task IDs for resuming
    resume_json = [
        {
            'dataset': dataset,
            'task_name': task.name,
            'task_id': task.id,
            'local_images_list_path': task.local_images_list_path
        }
        for dataset in tasks_by_dataset
        for task in tasks_by_dataset[dataset]
    ]
    with open(resume_file_path, 'w') as f:
        json.dump(resume_json, f, indent=1)
    return tasks_by_dataset
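
# Usage sketch (hedged): all values below are hypothetical placeholders. The
# SAS token must have write permission on the container receiving the task
# lists, and the caller must be allow-listed by the batch detection API.
#
#     tasks_by_dataset = submit_batch_detection_api(
#         images_to_detect=images_to_detect,
#         task_lists_dir='run_idfg2/task_lists',
#         detector_version='4.1',
#         account='mystorageaccount',
#         container='task-lists',
#         sas_token='<sas-token-with-write-access>',
#         caller='<allow-listed-caller>',
#         batch_detection_api_url='https://<batch-detection-api-url>',
#         resume_file_path='run_idfg2/batchapi_resume.json')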