コード例 #1
0
 def test_resize_bbox(self):
     # Padding the bounding box is expected to move the bbox edges outward and
     # enlarge the bbox shape, while the inner mask edges and shape stay as
     # they were before padding.
     tight_edges = (716, 745, 1039, 1123)
     tight_shape = (30, 85)
     pad = (9, 9, 9, 9)

     mask = SubsetMask(huc10190004.get('conus1_mask'))

     # Before padding: bbox and inner mask coincide.
     self.assertSequenceEqual(tight_edges, mask.inner_mask_edges)
     self.assertSequenceEqual(tight_edges, mask.bbox_edges)
     self.assertSequenceEqual(tight_shape, mask.bbox_shape)
     self.assertSequenceEqual(tight_shape, mask.inner_mask_shape)

     mask.add_bbox_to_mask(padding=pad)

     # After padding: only the bbox has changed.
     self.assertSequenceEqual((707, 754, 1030, 1132), mask.bbox_edges)
     self.assertSequenceEqual(tight_edges, mask.inner_mask_edges)
     self.assertSequenceEqual(tight_shape, mask.inner_mask_shape)
     self.assertSequenceEqual((48, 103), mask.bbox_shape)
     self.assertEqual(-999, mask.no_data_value)
     self.assertEqual(0, mask.bbox_val)
     self.assertSequenceEqual(pad, mask.get_bbox().get_padding())

     # Persist the padded mask to a GeoTIFF.
     mask.write_mask_to_tif("WBDHU8_conus1_mask_padded.tif")
コード例 #2
0
ファイル: rasterizer.py プロジェクト: vineetbansal/Subsetting
class ShapefileRasterizer:
    """Class for converting shapefile to raster for use as mask

    The raster is created with the size, projection and geotransform of a
    reference GDAL dataset; selected shapefile features are burned in as 1.
    """

    def __init__(self,
                 input_path,
                 shapefile_name,
                 reference_dataset,
                 no_data=NO_DATA,
                 output_path='.'):
        """
        Parameters
        ----------
        input_path : str
            path to input files (shapefile set)
        shapefile_name : str
            name of shapefile dataset (without the '.shp' extension)
        reference_dataset : gdal.dataset
            gdal dataset defining the overall domain
        no_data : int, optional
            value to write for no_data cells (defaults to the module-level
            NO_DATA constant)
        output_path : str, optional
            where to write the outputs (default '.')

        Raises
        ------
        Exception
            if no_data is 0 or 1, which are reserved for the bounding box and
            mask values respectively

        Returns
        -------
        ShapefileRasterizer
        """
        if no_data in (0, 1):
            raise Exception(
                f'ShapfileRasterizer: '
                f'Do not used reserved values 1 or 0 for no_data value: got no_data={no_data}'
            )
        self.shapefile_path = input_path
        self.output_path = output_path
        self.shapefile_name = shapefile_name
        self.ds_ref = reference_dataset
        self.no_data = no_data
        # TODO Handle shape extension using Pathlib.path
        self.full_shapefile_path = os.path.join(
            self.shapefile_path, '.'.join((self.shapefile_name, 'shp')))
        # Only warns about missing component files; does not raise.
        self.check_shapefile_parts()
        # Populated by reproject_and_mask()
        self.subset_mask = None

    def __repr__(self):
        # The closing parenthesis balances the one opened after the class name.
        return (
            f"{self.__class__.__name__}(shapefile_path:{self.shapefile_path!r}, "
            f"shapefile_name:{self.shapefile_name!r}, output_path:{self.output_path!r}, "
            f"ds_ref:{self.ds_ref!r}, "
            f"no_data:{self.no_data!r}, full_shapefile_path:{self.full_shapefile_path!r}, "
            f"subset_mask:{self.subset_mask!r})"
        )

    def check_shapefile_parts(self):
        """verify the required parts of a shapefile are present in the same folder

        Logs a warning for each of the .shp, .dbf, .prj and .shx files that is
        not found in the shapefile path.

        Returns
        -------
        None
        """
        shape_parts = [
            ".".join((self.shapefile_name, ext))
            for ext in ('shp', 'dbf', 'prj', 'shx')
        ]
        for shp_component_file in shape_parts:
            if not os.path.isfile(
                    os.path.join(self.shapefile_path, shp_component_file)):
                logging.warning(f'Shapefile path missing {shp_component_file}')

    def reproject_and_mask(self,
                           dtype=gdal.GDT_Int32,
                           no_data=None,
                           attribute_name='OBJECTID',
                           attribute_ids=None):
        """rasterize the selected shapefile features onto the reference grid

        Creates an in-memory (/vsimem) GeoTIFF matching the reference dataset,
        fills it with no_data, burns the value 1 for the selected features and
        stores the result as a SubsetMask in self.subset_mask.

        Parameters
        ----------
        attribute_ids : list of ints
            list of attribute ID values to select (Default value = None,
            which selects [1])
        dtype : gdal.datatype
            the datatype to write (Default value = gdal.GDT_Int32)
        no_data : int
            no_data value to use (Default value = None, which uses
            self.no_data)
        attribute_name : str
            field in the shapefile to trace (Default value = 'OBJECTID')

        Raises
        ------
        Exception
            if the shapefile cannot be opened, or if gdal.RasterizeLayer
            returns a non-zero code

        Returns
        -------
        str
            path (virtual mem) to the reprojected full_dim_mask

        """
        if attribute_ids is None:
            attribute_ids = [1]
        if no_data is None:
            no_data = self.no_data

        # Open the shapefile first: ogr.Open returns None on failure when GDAL
        # exceptions are not enabled, and failing here avoids leaving an
        # unused in-memory raster behind.
        shp_source = ogr.Open(self.full_shapefile_path)
        if shp_source is None:
            msg = f'unable to open shapefile: {self.full_shapefile_path}'
            logging.error(msg)
            raise Exception(msg)
        shp_layer = shp_source.GetLayer()

        # Target raster mirrors the reference dataset's grid
        geom_ref = self.ds_ref.GetGeoTransform()
        tif_path = f'/vsimem/{self.shapefile_name}.tif'
        target_ds = gdal.GetDriverByName('GTiff').Create(
            tif_path, self.ds_ref.RasterXSize, self.ds_ref.RasterYSize, 1,
            dtype)
        target_ds.SetProjection(self.ds_ref.GetProjection())
        target_ds.SetGeoTransform(geom_ref)
        target_ds.GetRasterBand(1).SetNoDataValue(no_data)
        target_ds.GetRasterBand(1).Fill(no_data)

        # TODO: How to detect if the shape geometries extend beyond our reference bounds?
        # Filter by the shapefile attribute IDs we want
        shp_layer.SetAttributeFilter(
            f'{attribute_name} in ({",".join([str(i) for i in attribute_ids])})'
        )
        # Rasterize layer
        rtn_code = gdal.RasterizeLayer(target_ds, [1],
                                       shp_layer,
                                       burn_values=[1])
        if rtn_code != 0:
            msg = f'error rasterizing layer: {shp_layer}, gdal returned non-zero value: {rtn_code}'
            # logging.error, not logging.exception: there is no active
            # exception here, so a traceback would be meaningless.
            logging.error(msg)
            raise Exception(msg)

        target_ds.FlushCache()
        logging.info(
            f'reprojected shapefile from {str(shp_layer.GetSpatialRef()).replace(chr(10), "")} '
            f'with extents {shp_layer.GetExtent()} '
            f'to {self.ds_ref.GetProjectionRef()} with transform {self.ds_ref.GetGeoTransform()}'
        )
        self.subset_mask = SubsetMask(tif_path)
        return tif_path

    def rasterize_shapefile_to_disk(self,
                                    out_dir=None,
                                    out_name=None,
                                    padding=(0, 0, 0, 0),
                                    attribute_name='OBJECTID',
                                    attribute_ids=None):
        """rasterize a shapefile to disk in the projection and extents of the reference dataset

        Writes the mask GeoTIFF and a 'bbox.txt' file into out_dir.

        Parameters
        ----------
        out_dir : str
            directory to write outputs (Default value = None, which uses
            self.output_path)
        out_name : str
            filename for outputs (Default value = None, which uses
            '<shapefile_name>.tif')
        padding : tuple
            optional padding to add 0's around full_dim_mask (Default value = (0,0,0,0))
        attribute_name : str
            optional name of shapefile attribute to select on (Default value = 'OBJECTID')
        attribute_ids : list
            optional list of attribute ids in shapefile to select for full_dim_mask
            (Default value = None, which selects [1])

        Returns
        -------
        ndarray
            the mask_array of the resulting SubsetMask; described by the
            original documentation as a 3d array with no_data to extents,
            0 in bounding box, 1 in full_dim_mask region

        """
        if attribute_ids is None:
            attribute_ids = [1]
        if out_name is None:
            out_name = f'{self.shapefile_name}.tif'
        if out_dir is None:
            out_dir = self.output_path
        # Populates self.subset_mask
        self.reproject_and_mask(attribute_ids=attribute_ids,
                                attribute_name=attribute_name)
        self.subset_mask.add_bbox_to_mask(padding=padding)
        self.subset_mask.write_mask_to_tif(
            filename=os.path.join(out_dir, out_name))
        self.subset_mask.write_bbox(os.path.join(out_dir, 'bbox.txt'))
        return self.subset_mask.mask_array