Example #1
0
    def test_file_exists(self):
        # Upload one object to s3://<bucket>/xxx/lorem.txt, then check how
        # file_exists treats the exact key, a partial key and the "directory".
        local_file = os.path.join(self.tmp_dir.name, 'lorem', 'ipsum.txt')
        make_dir(local_file, check_empty=False, use_dirname=True)
        str_to_file(self.lorem, local_file)

        s3_dir = 's3://{}/xxx/'.format(self.bucket_name)
        s3_file = 's3://{}/xxx/lorem.txt'.format(self.bucket_name)
        s3_partial_key = 's3://{}/xxx/lorem'.format(self.bucket_name)
        upload_or_copy(local_file, s3_file)

        # Expected: the directory only counts as existing when include_dir=True.
        self.assertTrue(file_exists(s3_dir, include_dir=True))
        self.assertTrue(file_exists(s3_file, include_dir=False))
        # Expected: a partial key (no trailing '/') does not match, even with
        # include_dir=True.
        self.assertFalse(file_exists(s3_partial_key, include_dir=True))
        self.assertFalse(file_exists(s3_dir, include_dir=False))
        self.assertFalse(
            file_exists(s3_dir + 'NOTPOSSIBLE', include_dir=False))
    def __init__(self,
                 uri,
                 extent,
                 crs_transformer,
                 tmp_dir,
                 vector_output=None,
                 class_config=None):
        """Constructor.

        Args:
            uri: (str) URI of the GeoTIFF file used for storing predictions
                as RGB values
            extent: (Box) the extent of the scene
            crs_transformer: (CRSTransformer)
            tmp_dir: (str) temporary directory to use
            vector_output: (None or array of dicts) containing vectorization
                configuration information
            class_config: (ClassConfig) with color values used to convert
                class ids to RGB values
        """
        self.uri = uri
        self.vector_output = vector_output
        self.extent = extent
        self.crs_transformer = crs_transformer
        self.tmp_dir = tmp_dir

        # Note: can't name this class_transformer due to Python using that
        # attribute.
        self.class_trans = (SegmentationClassTransformer(class_config)
                            if class_config else None)

        # A raster source is built only when a file already exists at `uri`;
        # otherwise self.source stays None.
        self.source = (RasterioSourceConfig(uris=[uri]).build(tmp_dir)
                       if file_exists(uri) else None)
Example #3
0
    def test_file_exists_local_true(self):
        # Write a real local file and check that file_exists reports it.
        local_path = os.path.join(self.tmp_dir.name, 'lorem', 'ipsum.txt')
        make_dir(os.path.dirname(local_path), check_empty=False)
        str_to_file(self.lorem, local_path)

        self.assertTrue(file_exists(local_path))
Example #4
0
    def test_file_exists_s3_true(self):
        # Create a local file, upload it to the test bucket, then check
        # that file_exists finds it at the S3 URI.
        local_path = os.path.join(self.tmp_dir.name, 'lorem', 'ipsum.txt')
        make_dir(os.path.dirname(local_path), check_empty=False)
        str_to_file(self.lorem, local_path)

        remote_path = 's3://{}/lorem.txt'.format(self.bucket_name)
        upload_or_copy(local_path, remote_path)
        self.assertTrue(file_exists(remote_path))
Example #5
0
    def test_copy_from_http(self):
        # The test expects download_if_needed to store the URL's content
        # under <tmp>/http/<host>/<url path>.
        url = ('https://raw.githubusercontent.com/tensorflow/models/'
               '17fa52864bfc7a7444a8b921d8a8eb1669e14ebd/README.md')
        download_if_needed(url, self.tmp_dir.name)

        local_copy = os.path.join(
            self.tmp_dir.name, 'http', 'raw.githubusercontent.com',
            'tensorflow/models',
            '17fa52864bfc7a7444a8b921d8a8eb1669e14ebd/README.md')
        self.assertTrue(file_exists(local_copy))
        # Remove the downloaded file once it has been checked.
        os.remove(local_copy)
Example #6
0
 def test_file_exists_s3_false(self):
     # The test expects no object at s3://<bucket>/hello.txt.
     self.assertFalse(
         file_exists('s3://{}/hello.txt'.format(self.bucket_name)))
Example #7
0
    def test_file_exists_local_false(self):
        # The parent directory is created, but this test never writes the
        # file itself.
        missing = os.path.join(self.tmp_dir.name, 'hello', 'hello.txt')
        make_dir(os.path.dirname(missing), check_empty=False)

        self.assertFalse(file_exists(missing))
Example #8
0
 def test_file_exists_http_false(self):
     # The test expects this URL (final component 'XXX') not to exist.
     missing_url = ('https://raw.githubusercontent.com/tensorflow/models/'
                    '17fa52864bfc7a7444a8b921d8a8eb1669e14ebd/XXX')
     self.assertFalse(file_exists(missing_url))
Example #9
0
def save_image_crop(image_uri,
                    image_crop_uri,
                    label_uri=None,
                    label_crop_uri=None,
                    size=600,
                    min_features=10,
                    vector_labels=True,
                    class_config=None):
    """Save a crop of an image to use for testing.

    Does nothing if a file already exists at image_crop_uri. Otherwise the
    image is split into size x size windows and the first acceptable window
    is cropped and saved. If label_uri is set and vector_labels is True, a
    window is acceptable only if an STRtree query against it returns
    >= min_features label geometries.

    While running, AWS_REQUEST_PAYER is set in os.environ if the S3 request
    payer is 'requester'. The original environment is restored afterwards.

    Args:
        image_uri: URI of original image
        image_crop_uri: URI of cropped image to save
        label_uri: optional URI of label file
        label_crop_uri: optional URI of cropped labels to save
        size: height and width of crop
        min_features: minimum number of label geometries a window must
            match when vector labels are used
        vector_labels: if True, label_uri is read as a GeoJSON vector source
            and the cropped labels are written as a GeoJSON
            FeatureCollection. If False, label_uri is cropped as an image
            using the same window as the image.
        class_config: class config passed to the GeoJSON vector source

    Raises:
        ValueError if cannot find a crop satisfying min_features constraint
            (including when the image yields no windows at all).
    """
    if file_exists(image_crop_uri):
        return

    print('Saving test crop to {}...'.format(image_crop_uri))
    old_environ = os.environ.copy()
    try:
        request_payer = S3FileSystem.get_request_payer()
        if request_payer == 'requester':
            os.environ['AWS_REQUEST_PAYER'] = request_payer

        # Context manager so the dataset handle is closed even when no
        # acceptable crop is found or a later step raises.
        with rasterio.open(image_uri) as im_dataset:
            height, width = im_dataset.height, im_dataset.width
            extent = Box(0, 0, height, width)
            windows = extent.get_windows(size, size)

            use_vector_labels = bool(label_uri and vector_labels)
            if use_vector_labels:
                crs_transformer = RasterioCRSTransformer.from_dataset(
                    im_dataset)
                geojson_vs_config = GeoJSONVectorSourceConfig(uri=label_uri)
                vs = geojson_vs_config.build(class_config, crs_transformer)
                geojson = vs.get_geojson()
                geoms = [shape(f['geometry']) for f in geojson['features']]
                tree = STRtree(geoms)

                def p2m(x, y, z=None):
                    # Convert pixel coordinates back to map coordinates.
                    return crs_transformer.pixel_to_map((x, y))

            # Initialized up front so an empty window list leads to the
            # ValueError below rather than an UnboundLocalError.
            use_window = False
            for window in windows:
                use_window = True
                if use_vector_labels:
                    # NOTE(review): assumes STRtree.query returns geometries
                    # (Shapely < 2.0). Shapely 2.x returns integer indices
                    # instead. Confirm the pinned Shapely version.
                    w_polys = tree.query(window.to_shapely())
                    use_window = len(w_polys) >= min_features
                    if use_window and label_crop_uri is not None:
                        print('Saving test crop labels to {}...'.format(
                            label_crop_uri))

                        label_crop_features = [
                            mapping(transform(p2m, wp)) for wp in w_polys
                        ]
                        label_crop_json = {
                            'type':
                            'FeatureCollection',
                            'features': [{
                                'geometry': f
                            } for f in label_crop_features]
                        }
                        json_to_file(label_crop_json, label_crop_uri)

                if use_window:
                    crop_image(image_uri, window, image_crop_uri)

                    if not vector_labels and label_uri and label_crop_uri:
                        crop_image(label_uri, window, label_crop_uri)

                    break

        if not use_window:
            raise ValueError('Could not find a good crop.')
    finally:
        os.environ.clear()
        os.environ.update(old_environ)