Example #1
def get_example_images(megadb_utils: MegadbUtils, dataset_name: str,
                       class_name: str) -> List[Optional[str]]:
    """Gets SAS URLs for images of a particular class from a given dataset."""
    datasets_table = megadb_utils.get_datasets_table()

    # match the class either at the sequence level or at the image level;
    # this query should be fairly fast, ~1 sec
    query_both_levels = f'''
    SELECT TOP {NUMBER_SEQUENCES_TO_QUERY} VALUE seq
    FROM seq
    WHERE ARRAY_CONTAINS(seq.class, "{class_name}")
        OR (SELECT VALUE COUNT(im)
            FROM im IN seq.images
            WHERE ARRAY_CONTAINS(im.class, "{class_name}")) > 0
    '''
    sequences = megadb_utils.query_sequences_table(
        query_both_levels, partition_key=dataset_name)

    num_samples = min(len(sequences), NUMBER_EXAMPLES_PER_SPECIES)
    sample_seqs = sample(sequences, num_samples)

    image_urls: List[Optional[str]] = []
    for seq in sample_seqs:
        sample_image = sample(seq['images'], 1)[0]  # sample 1 img per sequence
        img_path = MegadbUtils.get_full_path(
            datasets_table, dataset_name, sample_image['file'])
        img_path = urllib.parse.quote_plus(img_path)

        dataset_info = datasets_table[dataset_name]
        img_url = 'https://{}.blob.core.windows.net/{}/{}{}'.format(
            dataset_info["storage_account"],
            dataset_info["container"],
            img_path,
            dataset_info["container_sas_key"])
        image_urls.append(img_url)

    num_missing = NUMBER_EXAMPLES_PER_SPECIES - len(image_urls)
    if num_missing > 0:
        image_urls.extend([None] * num_missing)
    assert len(image_urls) == NUMBER_EXAMPLES_PER_SPECIES
    return image_urls
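A minimal usage sketch for the function above; the class name 'bobcat' is hypothetical, and the module-level constants (NUMBER_SEQUENCES_TO_QUERY, NUMBER_EXAMPLES_PER_SPECIES) and Cosmos DB credentials are assumed to be set up as in the original script:

# Usage sketch (illustrative): fetch example image URLs for one class in one dataset.
# Assumes NUMBER_SEQUENCES_TO_QUERY and NUMBER_EXAMPLES_PER_SPECIES are defined at
# module level, and that MegadbUtils can read the Cosmos DB endpoint and key from
# environment variables (as in example #3).
megadb_utils = MegadbUtils()
example_urls = get_example_images(megadb_utils,
                                  dataset_name='caltech',  # dataset name from example #2
                                  class_name='bobcat')     # hypothetical class name
for url in example_urls:
    print(url if url is not None else '(no example image found)')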
Example #2
def query_species_by_dataset(megadb_utils: MegadbUtils,
                             output_dir: str) -> None:
    """For each dataset, creates a JSON file specifying species counts.

    Skips dataset if a JSON file for it already exists.
    """
    # which datasets are already processed?
    queried_datasets = set(
        i.split('.json')[0] for i in os.listdir(output_dir)
        if i.endswith('.json'))

    datasets_table = megadb_utils.get_datasets_table()
    dataset_names = [i for i in datasets_table if i not in queried_datasets]

    print(
        f'{len(queried_datasets)} datasets already queried. Querying species '
        f'in {len(dataset_names)} datasets...')

    for dataset_name in dataset_names:
        print(f'Querying dataset {dataset_name}...')

        # sequence-level query should be fairly fast, ~1 sec
        query_seq_level = '''
        SELECT VALUE seq.class
        FROM seq
        WHERE ARRAY_LENGTH(seq.class) > 0
            AND NOT ARRAY_CONTAINS(seq.class, "empty")
            AND NOT ARRAY_CONTAINS(seq.class, "__label_unavailable")
        '''
        results = megadb_utils.query_sequences_table(
            query_seq_level, partition_key=dataset_name)

        counter = Counter()
        for i in results:
            counter.update(i)

        # handle cases where the class field is at the image level (images in a
        # sequence can have different class labels; the 'caltech' dataset is
        # like this); this query may take a long time, >1 hr
        query_image_level = '''
        SELECT VALUE seq.images
        FROM sequences seq
        WHERE (
            SELECT VALUE COUNT(im)
            FROM im IN seq.images
            WHERE ARRAY_LENGTH(im.class) > 0
        ) > 0
        '''

        start = datetime.now()
        results_im = megadb_utils.query_sequences_table(
            query_image_level, partition_key=dataset_name)
        elapsed = (datetime.now() - start).seconds
        print(f'- image-level query took {elapsed}s')

        for seq_images in results_im:
            for im in seq_images:
                assert 'class' in im
                counter.update(im['class'])

        with open(os.path.join(output_dir, f'{dataset_name}.json'), 'w') as f:
            json.dump(counter, f, indent=2)
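A minimal driver sketch for the function above, assuming MegadbUtils is importable as in the original module; the output directory name is hypothetical:

# Driver sketch (illustrative): query species counts for all datasets not yet processed.
import os

output_dir = 'species_by_dataset'  # hypothetical output location
os.makedirs(output_dir, exist_ok=True)

megadb_utils = MegadbUtils()  # Cosmos DB endpoint and key are read from the environment
query_species_by_dataset(megadb_utils, output_dir)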
Example #3
# Set to False if you do not want all results stored in a single JSON file.
consolidate_results = True


#%% Script

time_stamp = datetime.utcnow().strftime('%Y%m%d%H%M%S')

db_utils = MegadbUtils()  # reads the Cosmos DB endpoint and key from environment variables

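# `query`, `partition_key`, and `query_parameters` are defined earlier in the
# original script and are not shown in this excerpt. The commented-out values
# below are only a hypothetical illustration of a parameterized Cosmos DB query
# against the sequences table, not the original ones.
#
# query = '''
# SELECT TOP @top VALUE seq
# FROM seq
# WHERE ARRAY_CONTAINS(seq.class, @class_name)
# '''
# query_parameters = [{'name': '@top', 'value': 100},
#                     {'name': '@class_name', 'value': 'bobcat'}]
# partition_key = 'caltech'  # dataset name; the dataset is the partition key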
# execute the query
start_time = time.time()

result_iterable = db_utils.query_sequences_table(query=query,
                                                 partition_key=partition_key,
                                                 parameters=query_parameters)

# loop through and save the results
results = []
item_count = 0
part_count = 0
part_paths = []

for item in result_iterable:
    # MODIFY HERE depending on the query; by default, drop the Cosmos DB
    # metadata fields (keys starting with '_') from each returned item
    item_processed = {k: v for k, v in item.items() if not k.startswith('_')}

    results.append(item_processed)
    item_count += 1