Example #1
def handle_root(dataset_prefix: str) -> None:
    """Handle writing a new dataset to the root catalog"""
    results = S3_CLIENT.list_objects(
        Bucket=ResourceName.STORAGE_BUCKET_NAME.value, Prefix=CATALOG_KEY)

    # create root catalog if it doesn't exist
    if CONTENTS_KEY in results:
        root_catalog = Catalog.from_file(
            f"{S3_URL_PREFIX}{ResourceName.STORAGE_BUCKET_NAME.value}/{CATALOG_KEY}"
        )

    else:
        root_catalog = Catalog(
            id=ROOT_CATALOG_ID,
            title=ROOT_CATALOG_TITLE,
            description=ROOT_CATALOG_DESCRIPTION,
            catalog_type=CatalogType.SELF_CONTAINED,
        )
        root_catalog.set_self_href(
            f"{S3_URL_PREFIX}{ResourceName.STORAGE_BUCKET_NAME.value}/{CATALOG_KEY}"
        )

    dataset_path = f"{S3_URL_PREFIX}{ResourceName.STORAGE_BUCKET_NAME.value}/{dataset_prefix}"
    dataset_catalog = Catalog.from_file(f"{dataset_path}/{CATALOG_KEY}")

    root_catalog.add_child(dataset_catalog,
                           strategy=GeostoreSTACLayoutStrategy())
    root_catalog.normalize_hrefs(
        f"{S3_URL_PREFIX}{ResourceName.STORAGE_BUCKET_NAME.value}",
        strategy=GeostoreSTACLayoutStrategy(),
    )

    root_catalog.save(catalog_type=CatalogType.SELF_CONTAINED)
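The load-or-create pattern above works the same way outside of S3. A minimal sketch using plain pystac against a local directory (the paths and IDs here are illustrative, not taken from the original code):

import os
from pystac import Catalog, CatalogType

root_href = os.path.abspath('root/catalog.json')  # stand-in for the S3 URL

if os.path.exists(root_href):
    root_catalog = Catalog.from_file(root_href)
else:
    # no catalog yet: create one and anchor it at the chosen href
    root_catalog = Catalog(id='root', description='Root catalog')
    root_catalog.set_self_href(root_href)

root_catalog.normalize_hrefs(os.path.dirname(root_href))
root_catalog.save(catalog_type=CatalogType.SELF_CONTAINED)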
Example #2
def parse_stac(stac_uri: str) -> List[dict]:
    """Parse a STAC catalog JSON file to extract label URIs, images URIs,
    and AOIs.

    Note: This has been tested against STAC version 1.0.0, but not against
    other versions.

    Args:
        stac_uri (str): Path to the STAC catalog JSON file.

    Returns:
        List[dict]: A list of dicts with keys: "label_uri", "image_uris",
            "label_bbox", "image_bbox", "bboxes_intersect", and "aoi_geometry".
            Each dict corresponds to one label item and its associated image
            assets in the STAC catalog.
    """
    setup_stac_io()
    cat = Catalog.from_file(stac_uri)
    version: str = cat.to_dict()['stac_version']

    if not version.startswith('1.0'):
        log.warning(f'Parsing is not guaranteed to work correctly for '
                    f'STAC version != 1.0.*. Found version: {version}.')

    cat.make_all_asset_hrefs_absolute()

    label_items = [item for item in cat.get_all_items() if is_label_item(item)]
    image_items = [get_linked_image_item(item) for item in label_items]

    if len(label_items) == 0:
        raise ValueError('Unable to find any label items in STAC catalog.')

    out = []
    for label_item, image_item in zip(label_items, image_items):
        label_uri: str = list(label_item.assets.values())[0].href
        label_bbox = box(*label_item.bbox)
        aoi_geometry: Optional[dict] = label_item.geometry

        if image_item is not None:
            image_assets = [
                asset for asset in image_item.get_assets().values()
                if 'image' in asset.media_type
            ]
            image_uris = [asset.href for asset in image_assets]
            image_bbox = box(*image_item.bbox)
            bboxes_intersect = label_bbox.intersects(image_bbox)
        else:
            image_uris = []
            image_bbox = None
            bboxes_intersect = False

        out.append({
            'label_uri': label_uri,
            'image_uris': image_uris,
            'label_bbox': label_bbox,
            'image_bbox': image_bbox,
            'bboxes_intersect': bboxes_intersect,
            'aoi_geometry': aoi_geometry
        })
    return out
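A minimal usage sketch for parse_stac (the catalog path is illustrative, not from the original code):

# hypothetical invocation; 'data/catalog.json' is an illustrative path
for entry in parse_stac('data/catalog.json'):
    if entry['bboxes_intersect']:
        print(entry['label_uri'], '->', len(entry['image_uris']), 'image assets')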
Example #3
    def test_map_assets_single(self):
        changed_asset = 'd43bead8-e3f8-4c51-95d6-e24e750a402b'

        def asset_mapper(key, asset):
            if key == changed_asset:
                asset.title = 'NEW TITLE'

            return asset

        with TemporaryDirectory() as tmp_dir:
            catalog = TestCases.test_case_2()

            new_cat = catalog.map_assets(asset_mapper)

            new_cat.normalize_hrefs(os.path.join(tmp_dir, 'cat'))
            new_cat.save(catalog_type=CatalogType.ABSOLUTE_PUBLISHED)

            result_cat = Catalog.from_file(
                os.path.join(tmp_dir, 'cat', 'catalog.json'))

            found = False
            for item in result_cat.get_all_items():
                for key, asset in item.assets.items():
                    if key == changed_asset:
                        found = True
                        self.assertEqual(asset.title, 'NEW TITLE')
                    else:
                        self.assertNotEqual(asset.title, 'NEW TITLE')
            self.assertTrue(found)
Example #4
    def test_map_assets_tup(self):
        changed_assets = []

        def asset_mapper(key, asset):
            if 'geotiff' in asset.media_type:
                asset.title = 'NEW TITLE'
                changed_assets.append(key)
                return ('{}-modified'.format(key), asset)
            else:
                return asset

        with TemporaryDirectory() as tmp_dir:
            catalog = TestCases.test_case_2()

            new_cat = catalog.map_assets(asset_mapper)

            new_cat.normalize_hrefs(os.path.join(tmp_dir, 'cat'))
            new_cat.save(catalog_type=CatalogType.ABSOLUTE_PUBLISHED)

            result_cat = Catalog.from_file(
                os.path.join(tmp_dir, 'cat', 'catalog.json'))

            found = False
            not_found = False
            for item in result_cat.get_all_items():
                for key, asset in item.assets.items():
                    if key.replace('-modified', '') in changed_assets:
                        found = True
                        self.assertEqual(asset.title, 'NEW TITLE')
                    else:
                        not_found = True
                        self.assertNotEqual(asset.title, 'NEW TITLE')

            self.assertTrue(found)
            self.assertTrue(not_found)
Example #5
def mirror_collections(url, path='', **kwargs):
    API_RELS = ['search', 'collections', 'next']

    cat = Catalog.from_file(url)

    empty_children = []
    total_items = 0
    for provider in cat.get_children():
        links = provider.get_child_links()
        if links:
            # remove API-specific links
            for rel in API_RELS:
                provider.remove_links(rel)
            # remove links from the collections
            for collection in provider.get_children():
                found = hits(
                    collection.get_single_link('items').get_href(),
                    {'limit': 0})
                for rel in ['child', 'next', 'items']:
                    collection.remove_links(rel)
                logger.info(
                    f"{provider.id} - {collection.id}: {found} Items found")
                total_items += found
        else:
            empty_children.append(provider)

    for c in empty_children:
        cat.remove_child(c.id)

    logger.info(f"{total_items} total Items found")
    cat.catalog_type = CatalogType.RELATIVE_PUBLISHED

    return cat
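The hits helper used above is not part of the snippet. A plausible reconstruction, assuming a STAC API items endpoint that reports the total match count (the helper and its field handling are assumptions, not the original code):

import requests

def hits(url, params):
    # hypothetical: request zero items and read the match count from
    # the response metadata
    resp = requests.get(url, params=params)
    resp.raise_for_status()
    body = resp.json()
    # STAC API 1.0 reports 'numberMatched'; some older APIs used 'context.matched'
    return body.get('numberMatched',
                    body.get('context', {}).get('matched', 0))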
Example #6
def lambda_handler(event, context):
    logger.debug('Event: %s' % json.dumps(event))

    root_cat = get_root_catalog()

    # check if this is a Collection and, if so, add it to Cirrus
    if 'extent' in event:
        # add to static catalog (pystac needs an object, not a raw dict)
        root_cat.add_child(Collection.from_dict(event))

        # send to Cirrus Publish SNS
        response = snsclient.publish(TopicArn=PUBLISH_TOPIC,
                                     Message=json.dumps(event))
        logger.debug(f"SNS Publish response: {json.dumps(response)}")

    # check if URL to catalog
    if 'catalog_url' in event:
        cat = Catalog.from_file(event['catalog_url'])

        for child in cat.get_children():
            if isinstance(child, Collection):
                child.remove_links('child')
                # record where the collection was copied from
                child.add_link(Link('copied_from', child.get_self_href()))
                root_cat.add_child(child)
                child_json = json.dumps(child.to_dict())
                logger.debug(f"Publishing {child.id}: {child_json}")
                response = snsclient.publish(TopicArn=PUBLISH_TOPIC,
                                             Message=child_json)
                logger.debug(f"SNS Publish response: {json.dumps(response)}")

    root_cat.normalize_and_save(ROOT_URL, CatalogType.ABSOLUTE_PUBLISHED)
Example #7
async def cli():
    args = parse_args(sys.argv[1:])

    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
    # quiet loggers
    for lg in [
            'httpx', 'urllib3', 'botocore', 'boto3', 'aioboto3', 'aiobotocore'
    ]:
        logging.getLogger(lg).propagate = False

    cmd = args.pop('command')
    if cmd == 'create':
        # create initial catalog through to collections
        cat = mirror_collections(args['url'], args['path'])
        cat.normalize_hrefs(args['path'])
        await cat.save()
    elif cmd == 'update':
        cat = Catalog.from_file(args['cat'])
        collection = cat.get_child(args['provider']).get_child(
            args['collection'])

        url = f"{args['url']}/{args['provider']}/collections/{args['collection']}/items"
        params = {
            'limit': args['limit'],
        }
        if args['datetime'] is not None:
            params['datetime'] = args['datetime']
        await mirror_items(collection,
                           url,
                           params,
                           max_sync_queries=args['max_sync_queries'],
                           item_template=args['item_template'])
Example #8
def parse_stac(stac_uri):
    setup_stac_s3()
    cat = Catalog.from_file(stac_uri)
    cat.make_all_asset_hrefs_absolute()
    labels_uri = None
    labels_box = None
    for item in cat.get_all_items():
        if isinstance(item, LabelItem):
            labels_uri = list(item.assets.values())[0].href
            labels_box = box(*item.bbox)

    # fail fast: labels_box is needed below
    if not labels_uri:
        raise ValueError('Unable to read labels URI from STAC.')

    # only use geotiffs that intersect with bbox of labels
    geotiff_uris = []
    for item in cat.get_all_items():
        if not isinstance(item, LabelItem):
            geotiff_uri = list(item.assets.values())[0].href
            geotiff_box = box(*item.bbox)
            if labels_box.intersects(geotiff_box):
                geotiff_uri = geotiff_uri.replace('%7C', '|')
                geotiff_uris.append(geotiff_uri)

    if not geotiff_uris:
        raise ValueError('Unable to read GeoTIFF URIs from STAC.')
    return labels_uri, labels_box, geotiff_uris
Example #9
    def test_case_4():
        """Test case that is based on a local copy of the Tier 1 dataset from
        DrivenData's OpenCities AI Challenge.
        See: https://www.drivendata.org/competitions/60/building-segmentation-disaster-resilience
        """
        return Catalog.from_file(
            TestCases.get_path('data-files/catalogs/test-case-4/catalog.json'))
Example #10
def handler(event, context={}):
    # if this is batch, output to stdout
    if not hasattr(context, "invoked_function_arn"):
        logger.addHandler(logging.StreamHandler())

    logger.debug('Event: %s' % json.dumps(event))

    # parse input
    url = event.get('url')
    batch = event.get('batch', False)
    process = event['process']

    if batch and hasattr(context, "invoked_function_arn"):
        submit_batch_job(event,
                         context.invoked_function_arn,
                         definition='lambda-as-batch',
                         name='feed-stac-crawl')
        return

    cat = Catalog.from_file(url)

    for item in cat.get_all_items():
        payload = {
            'type': 'FeatureCollection',
            'features': [item.to_dict()],
            'process': process
        }
        SNS_CLIENT.publish(TopicArn=SNS_TOPIC, Message=json.dumps(payload))
Example #11
    def test_read_remote(self):
        catalog_url = ('https://raw.githubusercontent.com/radiantearth/stac-spec/'
                       'v{}'
                       '/extensions/label/examples/multidataset/catalog.json'.format(STAC_VERSION))
        cat = Catalog.from_file(catalog_url)

        zanzibar = cat.get_child('zanzibar-collection')

        self.assertEqual(len(list(zanzibar.get_items())), 2)
Example #12
    def test_read_remote(self):
        # TODO: Move this URL to the main stac-spec repo once the example JSON is fixed.
        catalog_url = (
            'https://raw.githubusercontent.com/lossyrob/stac-spec/0.9.0/pystac-upgrade-fixes'
            '/extensions/label/examples/multidataset/catalog.json')
        cat = Catalog.from_file(catalog_url)

        zanzibar = cat.get_child('zanzibar-collection')

        self.assertEqual(len(list(zanzibar.get_items())), 2)
Example #13
def get_scenes(json_file: str,
               class_config: ClassConfig,
               class_id_filter_dict: dict,
               catalog_dir: str,
               imagery_dir: str,
               train_crops: Optional[List[CropOffsets]] = None,
               val_crops: Optional[List[CropOffsets]] = None,
               N: Optional[int] = None
               ) -> Tuple[List[SceneConfig], List[SceneConfig]]:
    # avoid mutable default arguments
    train_crops = train_crops if train_crops is not None else []
    val_crops = val_crops if val_crops is not None else []

    train_scenes = []
    val_scenes = []
    with open(json_file, 'r') as f:
        for catalog_imagery in json.load(f):
            catalog = catalog_imagery.get('catalog')
            catalog = catalog.strip()
            catalog = f'{catalog_dir}/{catalog}'
            catalog = catalog.replace('s3://', '/vsizip/vsis3/')
            (labelss, imagerys) = hrefs_from_catalog(
                Catalog.from_file(root_of_tarball(catalog)), N)
            imagery_name = imagery = catalog_imagery.get('imagery')
            if imagery_name is not None:
                imagery = imagery.strip()
                imagery = f'{imagery_dir}/{imagery}'
                if '.zip' in imagery:
                    imagery = imagery.replace('s3://', '/vsizip/vsis3/')
                else:
                    imagery = imagery.replace('s3://', '/vsis3/')
                imagerys = [imagery] * len(labelss)
            else:
                imagerys = map(lambda i: i.replace('_rgb', ''), imagerys)
                if not imagery_dir.endswith('/'):
                    imagery_dir = imagery_dir + '/'
                imagerys = list(
                    map(lambda i: i.replace('./', imagery_dir), imagerys))
            h = hashlib.sha256(catalog.encode()).hexdigest()
            del imagery
            print('imagery', imagerys)
            print('labels', labelss)

            make_scene = partial(hrefs_to_sceneconfig,
                                 class_id_filter_dict=class_id_filter_dict)
            for j, (labels, imagery) in enumerate(zip(labelss, imagerys)):
                for i, crop in enumerate(train_crops):
                    scene = make_scene(name=f'{h}-train-{i}-{j}',
                                       extent_crop=crop,
                                       imagery=imagery,
                                       labels=labels)
                    train_scenes.append(scene)
                for i, crop in enumerate(val_crops):
                    scene = make_scene(name=f'{h}-val-{i}-{j}',
                                       extent_crop=crop,
                                       imagery=imagery,
                                       labels=labels)
                    val_scenes.append(scene)
    return train_scenes, val_scenes
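hrefs_from_catalog is not shown in this example. A hypothetical sketch of what it could look like, inferred only from how its return value is used above (the name, signature, and the label/imagery split heuristic are all assumptions):

from typing import List, Optional, Tuple
from pystac import Catalog

def hrefs_from_catalog(catalog: Catalog,
                       n: Optional[int] = None
                       ) -> Tuple[List[str], List[str]]:
    # collect the first asset href of each label item and each imagery item,
    # truncating both lists to n entries if n is given
    labels, imagery = [], []
    for item in catalog.get_all_items():
        href = list(item.assets.values())[0].href
        if 'label' in item.id:  # assumed convention for label items
            labels.append(href)
        else:
            imagery.append(href)
    return labels[:n], imagery[:n]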
Example #14
    def test_when_a_service_completes_it_writes_a_output_catalog_to_the_output_dir(
            self):
        with cli_parser('--harmony-action', 'invoke', '--harmony-input',
                        '{"test": "input"}', '--harmony-sources',
                        'example/source/catalog.json',
                        '--harmony-metadata-dir', self.workdir) as parser:
            args = parser.parse_args()
            cli.run_cli(parser, args, MockAdapter, cfg=self.config)
            output = Catalog.from_file(
                os.path.join(self.workdir, 'catalog.json'))
            # call validate(); asserting on the bound method is always truthy
            self.assertTrue(output.validate())
Example #15
    def test_reading_iterating_and_writing_works_as_expected(self):
        """ Test case to cover issue #88 """
        stac_uri = 'tests/data-files/catalogs/test-case-6/catalog.json'
        cat = Catalog.from_file(stac_uri)

        # Iterate over the items. This was causing failures in
        # the later iterations as per issue #88
        for item in cat.get_all_items():
            pass

        with TemporaryDirectory() as tmp_dir:
            new_stac_uri = os.path.join(tmp_dir, 'test-case-6')
            cat.normalize_hrefs(new_stac_uri)
            cat.save(catalog_type=CatalogType.SELF_CONTAINED)

            # Open the local copy and iterate over it.
            cat2 = Catalog.from_file(os.path.join(new_stac_uri, 'catalog.json'))

            for item in cat2.get_all_items():
                # Iterate again over the items. This would fail in #88
                pass
Example #16
    def test_validate_label(self):
        with open(self.label_example_1_uri) as f:
            label_example_1_dict = json.load(f)
        pystac.validation.validate_dict(label_example_1_dict, "ITEM")

        with TemporaryDirectory() as tmp_dir:
            cat_dir = os.path.join(tmp_dir, 'catalog')
            catalog = TestCases.test_case_1()
            catalog.normalize_and_save(cat_dir, catalog_type=CatalogType.SELF_CONTAINED)

            cat_read = Catalog.from_file(os.path.join(cat_dir, 'catalog.json'))
            label_item_read = cat_read.get_item("area-2-2-labels", recursive=True)
            label_item_read.validate()
Example #17
def get_root_catalog() -> Catalog:
    """Get Cirrus root catalog from s3

    Returns:
        Catalog: STAC root catalog
    """
    if s3().exists(ROOT_URL):
        cat = Catalog.from_file(ROOT_URL)
    else:
        catid = DATA_BUCKET.split('-data-')[0]
        cat = Catalog(id=catid, description=DESCRIPTION)
    # describe() prints and returns None, so log the id instead
    logger.debug(f"Fetched {cat.id}")
    return cat
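The s3() helper comes from the surrounding project (boto3utils, in Cirrus), not from pystac. A rough, hypothetical equivalent of the exists check using plain boto3 (the function name and error handling are assumptions):

import boto3
from urllib.parse import urlparse

def s3_exists(url):
    # hypothetical stand-in for s3().exists(): HEAD the object and
    # treat any client error as "not found"
    parsed = urlparse(url)
    try:
        boto3.client('s3').head_object(Bucket=parsed.netloc,
                                       Key=parsed.path.lstrip('/'))
        return True
    except Exception:
        return False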
Example #18
    def test_set_hrefs_manually(self):
        catalog = TestCases.test_case_1()

        # Modify the datetimes
        year = 2004
        month = 2
        for item in catalog.get_all_items():
            item.datetime = item.datetime.replace(year=year, month=month)
            year += 1
            month += 1

        with TemporaryDirectory() as tmp_dir:
            for root, _, items in catalog.walk():

                # Set root's HREF based off the parent
                parent = root.get_parent()
                if parent is None:
                    root_dir = tmp_dir
                else:
                    d = os.path.dirname(parent.get_self_href())
                    root_dir = os.path.join(d, root.id)
                root_href = os.path.join(root_dir, root.DEFAULT_FILE_NAME)
                root.set_self_href(root_href)

                # Set each item's HREF based on its datetime
                for item in items:
                    item_href = '{}/{}-{}/{}.json'.format(
                        root_dir, item.datetime.year, item.datetime.month,
                        item.id)
                    item.set_self_href(item_href)

            catalog.save(catalog_type=CatalogType.SELF_CONTAINED)

            read_catalog = Catalog.from_file(
                os.path.join(tmp_dir, 'catalog.json'))

            for root, _, items in read_catalog.walk():
                parent = root.get_parent()
                if parent is None:
                    self.assertEqual(root.get_self_href(),
                                     os.path.join(tmp_dir, 'catalog.json'))
                else:
                    d = os.path.dirname(parent.get_self_href())
                    self.assertEqual(
                        root.get_self_href(),
                        os.path.join(d, root.id, root.DEFAULT_FILE_NAME))
                for item in items:
                    end = '{}-{}/{}.json'.format(item.datetime.year,
                                                 item.datetime.month, item.id)
                    self.assertTrue(item.get_self_href().endswith(end))
Example #19
def get_item(catalog):
    cat = Catalog.from_file(catalog)

    try:
        collection = next(cat.get_children())
        item = next(collection.get_items())
    except StopIteration:
        # catalog has no children; take the first item directly
        item = next(cat.get_items())

    return item
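Usage is a one-liner (the path is illustrative):

item = get_item('mirror/catalog.json')
print(item.id)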
Example #20
    def test_create_and_read(self):
        with TemporaryDirectory() as tmp_dir:
            cat_dir = os.path.join(tmp_dir, 'catalog')
            catalog = TestCases.test_case_1()

            catalog.normalize_and_save(cat_dir, catalog_type=CatalogType.ABSOLUTE_PUBLISHED)

            read_catalog = Catalog.from_file('{}/catalog.json'.format(cat_dir))

            collections = catalog.get_children()
            self.assertEqual(len(list(collections)), 2)

            items = read_catalog.get_all_items()

            self.assertEqual(len(list(items)), 8)
Example #21
def lambda_handler(event, context={}):
    logger.debug('Event: %s' % json.dumps(event))

    # check if collection and if so, add to Cirrus
    if 'extent' in event:
        stac.add_collections([Collection.from_dict(event)])

    # check if URL to catalog - ingest all collections
    if 'catalog_url' in event:
        collections = []
        cat = Catalog.from_file(event['catalog_url'])
        for child in cat.get_children():
            if isinstance(child, Collection):
                collections.append(child)
        stac.add_collections(collections)
Example #22
def get_root_catalog():
    """Get Cirrus root catalog from s3

    Returns:
        Catalog: STAC root catalog
    """
    caturl = f"{ROOT_URL}/catalog.json"
    if s3().exists(caturl):
        cat = Catalog.from_file(caturl)
    else:
        catid = DATA_BUCKET.split('-data-')[0]
        cat = Catalog(id=catid, description=DESCRIPTION)
        cat.normalize_and_save(ROOT_URL, CatalogType.ABSOLUTE_PUBLISHED)
    # describe() prints and returns None, so log the id instead
    logger.debug(f"Fetched {cat.id}")
    return cat
Example #23
def handle_dataset(version_metadata_key: str) -> None:
    """Handle writing a new dataset version to the dataset catalog"""
    storage_bucket_path = f"{S3_URL_PREFIX}{ResourceName.STORAGE_BUCKET_NAME.value}"
    dataset_prefix = version_metadata_key.split("/", maxsplit=1)[0]
    dataset_catalog = Catalog.from_file(
        f"{storage_bucket_path}/{dataset_prefix}/{CATALOG_KEY}")

    dataset_version_metadata = STAC_IO.read_stac_object(
        f"{storage_bucket_path}/{version_metadata_key}")

    dataset_catalog.add_child(dataset_version_metadata,
                              strategy=GeostoreSTACLayoutStrategy())

    dataset_catalog.normalize_hrefs(f"{storage_bucket_path}/{dataset_prefix}",
                                    strategy=GeostoreSTACLayoutStrategy())
    dataset_catalog.save(catalog_type=CatalogType.SELF_CONTAINED)
Example #24
    def test_validate_label(self) -> None:
        with open(self.label_example_1_uri, encoding="utf-8") as f:
            label_example_1_dict = json.load(f)
        pystac.validation.validate_dict(label_example_1_dict,
                                        pystac.STACObjectType.ITEM)

        with tempfile.TemporaryDirectory() as tmp_dir:
            cat_dir = os.path.join(tmp_dir, "catalog")
            catalog = TestCases.test_case_1()
            catalog.normalize_and_save(cat_dir,
                                       catalog_type=CatalogType.SELF_CONTAINED)

            cat_read = Catalog.from_file(os.path.join(cat_dir, "catalog.json"))
            label_item_read = cat_read.get_item("area-2-2-labels",
                                                recursive=True)
            assert label_item_read is not None
            label_item_read.validate()
Example #25
    def test_validate_label(self):
        sv = SchemaValidator()
        with open(self.label_example_1_uri) as f:
            label_example_1_dict = json.load(f)
        sv.validate_dict(label_example_1_dict, LabelItem)

        with TemporaryDirectory() as tmp_dir:
            cat_dir = os.path.join(tmp_dir, 'catalog')
            catalog = TestCases.test_case_1()
            label_item = LabelItem.from_dict(label_example_1_dict)
            catalog.add_item(label_item)
            catalog.normalize_and_save(cat_dir,
                                       catalog_type=CatalogType.SELF_CONTAINED)

            cat_read = Catalog.from_file(os.path.join(cat_dir, 'catalog.json'))
            label_item_read = cat_read.get_item("label-example-1-label-item")
            sv = SchemaValidator()
            sv.validate_object(label_item_read)
Example #26
    def __init__(self, csv_file, cat_path, root_dir, transform=None):
        """Initialisation of the dataset.

        Create a dataset object with all the required information.

        Arguments:
            csv_file {str} -- Path to the csv file with the dataset map
            cat_path {str} -- Path of the catalog file
            root_dir {str} -- Root directory of the images

        Keyword Arguments:
            transform {callable, optional} -- Optional transform to be
                applied on a sample. (default: {None})
        """
        self.houses_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.train_cat = Catalog.from_file(cat_path)
        self.cols = {cols.id: cols for cols in self.train_cat.get_children()}

        self.items_labels = {}

        items_size = 0

        for col in list(self.cols.keys()):
            self.items_labels[str(col)] = {}
            for ii in self.cols[col].get_all_items():
                print("Init dataset : ", col, ii.id)
                if "-labels" in str(ii.id):
                    #rasterio.open(ii.make_asset_hrefs_absolute().assets['labels'].href).meta
                    gpd.read_file(
                        ii.make_asset_hrefs_absolute().assets['labels'].href)

                    one_item_label = self.cols[col].get_item(id=ii.id)
                    scene_labels_gdf = gpd.read_file(
                        one_item_label.assets['labels'].href)
                    self.items_labels[str(col)][str(ii.id)] = scene_labels_gdf

                    items_size += sys.getsizeof(scene_labels_gdf)
                    print("Size sum :", convert_bytes(items_size))

                else:
                    rasterio.open(ii.make_asset_hrefs_absolute().
                                  assets['image'].href).meta
Example #27
def main():
    args = parse_args()

    data_path = Path(args.data_path)

    to_save = data_path / f'test_tiles_{args.win_sz}'
    if not to_save.exists():
        to_save.mkdir(parents=True)

    test_cat = Catalog.from_file(str(data_path / 'test' / 'catalog.json'))
    with tqdm.tqdm(test_cat.get_items()) as pbar:
        for one_item in pbar:
            rst = rasterio.open(one_item.make_asset_hrefs_absolute().assets['image'].href)
            win_arr = rst.read()
            win_arr = np.transpose(win_arr, (1, 2, 0))[..., :3]
            win_arr = win_arr[..., ::-1]
            res = cv2.imwrite(str(to_save / f'{one_item.id}.jpg'), win_arr)
            assert res, f'Cannot write image {one_item.id}'
            rst.close()
Example #28
def get_config(runner, root_uri, catalog_root, epochs='20', batch_size='8'):
    # Read STAC catalog
    catalog: Catalog = Catalog.from_file(pystac_workaround(catalog_root))

    # TODO: pull desired channels from root collection properties
    channel_ordering: List[int] = [0, 1, 2]

    # TODO: pull ClassConfig info from root collection properties
    class_config: ClassConfig = ClassConfig(names=["land", "water"],
                                            colors=["brown", "blue"])

    dataset = build_dataset_from_catalog(catalog, channel_ordering,
                                         class_config)

    chip_sz = 512

    backend = PyTorchSemanticSegmentationConfig(
        model=SemanticSegmentationModelConfig(backbone=Backbone.resnet50),
        solver=SolverConfig(lr=1e-4,
                            num_epochs=int(epochs),
                            batch_sz=int(batch_size),
                            ignore_last_class=True),
        log_tensorboard=False,
        run_tensorboard=False,
        num_workers=0,
    )
    chip_options = SemanticSegmentationChipOptions(
        window_method=SemanticSegmentationWindowMethod.sliding,
        target_class_ids=[1],
        negative_survival_prob=0.125,
        stride=chip_sz // 2)

    return SemanticSegmentationConfig(
        root_uri=root_uri,
        dataset=dataset,
        backend=backend,
        train_chip_sz=chip_sz,
        predict_chip_sz=chip_sz,
        chip_options=chip_options,
        img_format='npy',
        label_format='png',
    )
Example #29
    def test_map_items(self):
        def item_mapper(item):
            item.properties['ITEM_MAPPER'] = 'YEP'
            return item

        with TemporaryDirectory() as tmp_dir:
            catalog = TestCases.test_case_1()

            new_cat = catalog.map_items(item_mapper)

            new_cat.normalize_hrefs(os.path.join(tmp_dir, 'cat'))
            new_cat.save(catalog_type=CatalogType.ABSOLUTE_PUBLISHED)

            result_cat = Catalog.from_file(os.path.join(tmp_dir, 'cat', 'catalog.json'))

            for item in result_cat.get_all_items():
                self.assertTrue('ITEM_MAPPER' in item.properties)

            for item in catalog.get_all_items():
                self.assertFalse('ITEM_MAPPER' in item.properties)
Example #30
def _build_adapter(AdapterClass, message_string, sources_path, data_location,
                   config):
    """
    Creates the adapter to be invoked for the given harmony input

    Parameters
    ----------
    AdapterClass : class
        The BaseHarmonyAdapter subclass to use to handle service invocations
    message_string : string
        The Harmony input message
    sources_path : string
        A file location containing a STAC catalog corresponding to the input message sources
    data_location : string
        The name of the directory where output should be written
    config : harmony.util.Config
        A configuration instance for this service

    Returns
    -------
        BaseHarmonyAdapter subclass instance
            The adapter to be invoked
    """
    catalog = Catalog.from_file(sources_path)
    secret_key = config.shared_secret_key

    if secret_key:
        decrypter = create_decrypter(bytes(secret_key, 'utf-8'))
    else:
        def identity(arg):
            return arg

        decrypter = identity

    message = Message(json.loads(message_string), decrypter)
    if data_location:
        message.stagingLocation = data_location
    adapter = AdapterClass(message, catalog)
    adapter.set_config(config)

    return adapter