예제 #1
0
def write_array_to_file(dst_filename: PathLikeOrStr,
                        a: MaybeSequence[np.ndarray],
                        gdal_dtype=None) -> gdal.Dataset:
    """Create a raster file from one or more 2D arrays.

    :param dst_filename: path of the raster to create; the output driver is
        selected from it with ``GetOutputDriverFor(..., is_raster=True)``.
    :param a: a single 2D array (written as one band), a 3D array whose
        first axis indexes the bands, or a sequence of 2D arrays (one per
        band).
    :param gdal_dtype: GDAL data type of the created bands. When None it is
        derived from the numpy dtype of the first band with
        ``gdal_array.flip_code``.
    :return: the newly created, still-open dataset.
    :raises Exception: if the driver is not registered, if the input does
        not have 2 or 3 dimensions, if no GDAL type matches the array dtype,
        if the dataset cannot be created, or if writing a band fails.
    """
    driver_name = GetOutputDriverFor(dst_filename, is_raster=True)
    driver = gdal.GetDriverByName(driver_name)
    if driver is None:
        raise Exception(f'"{driver_name}" driver not registered.')

    a_shape = a[0].shape
    if len(a_shape) == 1:
        # `a` itself is a 2D array (a[0] is one row): single-band raster
        a = [a]
    elif len(a_shape) != 2:
        raise Exception('Array should have 2 or 3 dimensions')
    # `a` is now indexable by band: either a 3D array or a sequence of 2D
    # arrays. len() works for both, whereas `.shape` does not exist on a
    # plain list/tuple of arrays.
    bands_count = len(a)
    y_size, x_size = a[0].shape

    if gdal_dtype is None:
        np_dtype = a[0].dtype
        gdal_dtype = gdal_array.flip_code(np_dtype)
        if gdal_dtype is None:
            raise Exception(f'No GDAL data type matches numpy dtype {np_dtype}')
    ds = driver.Create(dst_filename, x_size, y_size, bands_count, gdal_dtype)
    if ds is None:
        raise Exception(f'failed to create: {dst_filename}')

    for bnd_num in range(bands_count):
        bnd = ds.GetRasterBand(bnd_num + 1)
        if gdal_array.BandWriteArray(bnd, a[bnd_num], xoff=0, yoff=0) != 0:
            raise Exception('I/O error')

    return ds
예제 #2
0
def doit(src_filename, pct_filename, dst_filename=None, driver=None):
    """Copy `src_filename`, attach a color table to band 1 and write it out.

    :param src_filename: source raster, opened with ``open_ds``.
    :param pct_filename: file the color table is read from with
        ``get_color_table``; may be None.
    :param dst_filename: output path; may be None/empty (an empty name is
        passed to CreateCopy in that case).
    :param driver: output driver short name; when falsy it is guessed from
        `dst_filename` with ``GetOutputDriverFor``. 'MEM' returns the
        in-memory copy itself.
    :return: ``(out_ds, status)`` where status is 0 on success and 1 on
        failure (out_ds is then None).
    """

    # =============================================================================
    # Get the PCT.
    # =============================================================================

    ct = get_color_table(pct_filename)
    if pct_filename is not None and ct is None:
        print('No color table on file ', pct_filename)
        return None, 1

    # =============================================================================
    # Create a MEM clone of the source file.
    # =============================================================================

    src_ds = open_ds(src_filename)
    if src_ds is None:
        print('Unable to open %s' % src_filename)
        return None, 1

    mem_ds = gdal.GetDriverByName('MEM').CreateCopy('mem', src_ds)
    if mem_ds is None:
        print('Unable to create an in-memory copy of %s' % src_filename)
        return None, 1

    # =============================================================================
    # Assign the color table in memory.
    # =============================================================================

    mem_ds.GetRasterBand(1).SetRasterColorTable(ct)
    mem_ds.GetRasterBand(1).SetRasterColorInterpretation(gdal.GCI_PaletteIndex)

    # =============================================================================
    # Write the dataset to the output file.
    # =============================================================================

    if not driver:
        driver = GetOutputDriverFor(dst_filename)

    dst_driver = gdal.GetDriverByName(driver)
    if dst_driver is None:
        print('"%s" driver not registered.' % driver)
        return None, 1

    if driver.upper() == 'MEM':
        out_ds = mem_ds
    else:
        out_ds = dst_driver.CreateCopy(dst_filename or '', mem_ds)

    # Drop our own references; out_ds still holds the MEM dataset when it
    # is the output itself.
    mem_ds = None
    src_ds = None

    if out_ds is None:
        # CreateCopy failed: report it instead of returning success.
        print('Unable to create %s' % dst_filename)
        return None, 1

    return out_ds, 0
예제 #3
0
def gdal_sieve(src_filename: Optional[str] = None,
               dst_filename: PathLikeOrStr = None,
               driver_name: Optional[str] = None,
               mask: str = 'default',
               threshold: int = 2,
               connectedness: int = 4,
               quiet: bool = False):
    """Run gdal.SieveFilter on band 1 of `src_filename`.

    :param src_filename: input raster; its band 1 is filtered.
    :param dst_filename: output raster. When None the source is opened in
        update mode and its band 1 is used as the destination.
    :param driver_name: driver for the output; guessed from `dst_filename`
        with ``GetOutputDriverFor`` when None.
    :param mask: 'default' to use the source band's mask band, 'none' for no
        mask, otherwise the path of a raster whose band 1 is the mask.
    :param threshold: passed to gdal.SieveFilter.
    :param connectedness: passed to gdal.SieveFilter (4 or 8).
    :param quiet: when True no progress callback is used.
    :return: 1 on setup errors, otherwise the value returned by
        gdal.SieveFilter.
    """
    # =============================================================================
    # 	Verify we have next gen bindings with the sievefilter method.
    # =============================================================================
    try:
        gdal.SieveFilter
    except AttributeError:
        print('')
        print(
            'gdal.SieveFilter() not available.  You are likely using "old gen"'
        )
        print('bindings or an older version of the next gen bindings.')
        print('')
        return 1

    # =============================================================================
    # Open source file
    # =============================================================================

    if dst_filename is None:
        # In-place filtering needs write access to the source.
        src_ds = gdal.Open(src_filename, gdal.GA_Update)
    else:
        src_ds = gdal.Open(src_filename, gdal.GA_ReadOnly)

    if src_ds is None:
        print('Unable to open %s ' % src_filename)
        return 1

    srcband = src_ds.GetRasterBand(1)

    mask_ds = None
    if mask == 'default':
        maskband = srcband.GetMaskBand()
    elif mask == 'none':
        maskband = None
    else:
        mask_ds = gdal.Open(mask)
        if mask_ds is None:
            print('Unable to open %s ' % mask)
            return 1
        maskband = mask_ds.GetRasterBand(1)

    # =============================================================================
    #       Create output file if one is specified.
    # =============================================================================

    dst_ds = None
    if dst_filename is not None:
        if driver_name is None:
            driver_name = GetOutputDriverFor(dst_filename)

        drv = gdal.GetDriverByName(driver_name)
        if drv is None:
            print('"%s" driver not registered.' % driver_name)
            return 1
        dst_ds = drv.Create(dst_filename, src_ds.RasterXSize,
                            src_ds.RasterYSize, 1, srcband.DataType)
        if dst_ds is None:
            print('Unable to create %s' % dst_filename)
            return 1
        wkt = src_ds.GetProjection()
        if wkt != '':
            dst_ds.SetProjection(wkt)
        gt = src_ds.GetGeoTransform(can_return_null=True)
        if gt is not None:
            dst_ds.SetGeoTransform(gt)

        dstband = dst_ds.GetRasterBand(1)
        # GetNoDataValue() returns None when the source has no nodata value;
        # only propagate an actual value.
        nodata = srcband.GetNoDataValue()
        if nodata is not None:
            dstband.SetNoDataValue(nodata)
    else:
        dstband = srcband

    # =============================================================================
    # Invoke algorithm.
    # =============================================================================

    if quiet:
        prog_func = None
    else:
        prog_func = gdal.TermProgress_nocb

    result = gdal.SieveFilter(srcband,
                              maskband,
                              dstband,
                              threshold,
                              connectedness,
                              callback=prog_func)

    # Release datasets explicitly so outputs are flushed and closed.
    src_ds = None
    dst_ds = None
    mask_ds = None

    return result
예제 #4
0
def gdal_proximity(
    src_filename: Optional[str] = None,
    src_band_n: int = 1,
    dst_filename: Optional[str] = None,
    dst_band_n: int = 1,
    driver_name: Optional[str] = None,
    creation_type: str = 'Float32',
    creation_options: Optional[Sequence[str]] = None,
    alg_options: Optional[Sequence[str]] = None,
    quiet: bool = False):
    """Run gdal.ComputeProximity from a source band into a destination band.

    If `dst_filename` is identified as an existing raster and can be opened
    for update, band `dst_band_n` of it receives the result. Otherwise a new
    single-band raster of type `creation_type` is created with the source's
    size, geotransform and projection.

    :param src_filename: input raster.
    :param src_band_n: 1-based source band index.
    :param dst_filename: output raster.
    :param dst_band_n: 1-based band index used when writing into an existing
        output.
    :param driver_name: driver for a newly created output; guessed from
        `dst_filename` with ``GetOutputDriverFor`` when None.
    :param creation_type: GDAL data type name for a newly created output.
    :param creation_options: creation options for a newly created output.
    :param alg_options: options passed to gdal.ComputeProximity.
    :param quiet: when True no progress callback is used.
    :return: 1 on setup errors, otherwise the value returned by
        gdal.ComputeProximity.
    """

    # =============================================================================
    #    Open source file
    # =============================================================================
    creation_options = creation_options or []
    alg_options = alg_options or []
    src_ds = gdal.Open(src_filename)

    if src_ds is None:
        print('Unable to open %s' % src_filename)
        return 1

    srcband = src_ds.GetRasterBand(src_band_n)
    if srcband is None:
        print('Band %d does not exist in %s' % (src_band_n, src_filename))
        return 1

    # =============================================================================
    #       Try opening the destination file as an existing file.
    # =============================================================================
    # IdentifyDriver returns a Driver object (not a name), so keep it out of
    # `driver_name`: overwriting that would discard the caller's choice of
    # output driver. Any failure here falls back to creating the file.
    dst_ds = None
    try:
        if gdal.IdentifyDriver(dst_filename) is not None:
            dst_ds = gdal.Open(dst_filename, gdal.GA_Update)
    except Exception:
        dst_ds = None

    if dst_ds is not None:
        dstband = dst_ds.GetRasterBand(dst_band_n)
        if dstband is None:
            print('Band %d does not exist in %s' % (dst_band_n, dst_filename))
            return 1
    else:
        # =============================================================================
        #     Create output file.
        # =============================================================================
        if driver_name is None:
            driver_name = GetOutputDriverFor(dst_filename)

        drv = gdal.GetDriverByName(driver_name)
        if drv is None:
            print('"%s" driver not registered.' % driver_name)
            return 1
        dst_ds = drv.Create(dst_filename,
                            src_ds.RasterXSize, src_ds.RasterYSize, 1,
                            gdal.GetDataTypeByName(creation_type), creation_options)
        if dst_ds is None:
            print('Unable to create %s' % dst_filename)
            return 1

        dst_ds.SetGeoTransform(src_ds.GetGeoTransform())
        dst_ds.SetProjection(src_ds.GetProjectionRef())

        dstband = dst_ds.GetRasterBand(1)

    # =============================================================================
    #    Invoke algorithm.
    # =============================================================================

    if quiet:
        prog_func = None
    else:
        prog_func = gdal.TermProgress_nocb

    result = gdal.ComputeProximity(srcband, dstband, alg_options,
                                   callback=prog_func)

    # Release bands before their datasets so the output is flushed and closed.
    srcband = None
    dstband = None
    src_ds = None
    dst_ds = None

    return result
예제 #5
0
def pct2rgb(src_filename: PathLikeOrStr, pct_filename: Optional[PathLikeOrStr], dst_filename: PathLikeOrStr,
            band_number: int = 1, out_bands: int = 3, driver_name: Optional[str] = None):
    """Expand a color-table-indexed band into one band per color component.

    Output band ``i`` (0-based) holds component ``i`` of the color entry
    looked up for each source pixel value.

    :param src_filename: source raster.
    :param pct_filename: optional palette file; when None the color table of
        the source band is used.
    :param dst_filename: output raster.
    :param band_number: 1-based source band to expand.
    :param out_bands: number of output bands, 1 to 4 (one per color entry
        component).
    :param driver_name: output driver; guessed from `dst_filename` with
        ``GetOutputDriverFor`` when None.
    :return: the output dataset.
    :raises Exception: if `out_bands` is out of range, the source cannot be
        opened, the driver is not registered, or no color table is found.
    """
    # Only 4 color components (lookup tables below) are available.
    if not 1 <= out_bands <= 4:
        raise Exception(f'out_bands must be between 1 and 4, got {out_bands}')

    # Open source file
    src_ds = open_ds(src_filename)
    if src_ds is None:
        raise Exception(f'Unable to open {src_filename} ')

    src_band = src_ds.GetRasterBand(band_number)

    # ----------------------------------------------------------------------------
    # Ensure we recognise the driver.

    if driver_name is None:
        driver_name = GetOutputDriverFor(dst_filename)

    dst_driver = gdal.GetDriverByName(driver_name)
    if dst_driver is None:
        raise Exception(f'"{driver_name}" driver not registered.')

    # ----------------------------------------------------------------------------
    # Build color table.

    if pct_filename is not None:
        pal = get_color_palette(pct_filename)
        if pal.has_percents():
            min_val = src_band.GetMinimum()
            max_val = src_band.GetMaximum()
            if min_val is None or max_val is None:
                # The band does not record its range: compute it (approximate).
                min_val, max_val = src_band.ComputeRasterMinMax(True)
            pal.apply_percent(min_val, max_val)
        ct = get_color_table(pal)
    else:
        ct = src_band.GetRasterColorTable()

    if ct is None:
        raise Exception(f'No color table found (source: {src_filename}, palette: {pct_filename})')

    # One lookup array per color component, indexed by pixel value.
    ct_size = ct.GetCount()
    lookup = [np.arange(ct_size),
              np.arange(ct_size),
              np.arange(ct_size),
              np.ones(ct_size) * 255]

    for i in range(ct_size):
        entry = ct.GetColorEntry(i)
        for c in range(4):
            lookup[c][i] = entry[c]

    # ----------------------------------------------------------------------------
    # Create the working file. For non-GTiff output an intermediate
    # 'temp.tif' (relative to the current directory) is written first.

    if driver_name.lower() == 'gtiff':
        tif_filename = dst_filename
    else:
        tif_filename = 'temp.tif'

    gtiff_driver = gdal.GetDriverByName('GTiff')

    tif_ds = gtiff_driver.Create(tif_filename, src_ds.RasterXSize, src_ds.RasterYSize, out_bands)

    # ----------------------------------------------------------------------------
    # We should copy projection information and so forth at this point.

    tif_ds.SetProjection(src_ds.GetProjection())
    tif_ds.SetGeoTransform(src_ds.GetGeoTransform())
    if src_ds.GetGCPCount() > 0:
        tif_ds.SetGCPs(src_ds.GetGCPs(), src_ds.GetGCPProjection())

    # ----------------------------------------------------------------------------
    # Do the processing one scanline at a time.

    progress(0.0)
    for iY in range(src_ds.RasterYSize):
        src_data = src_band.ReadAsArray(0, iY, src_ds.RasterXSize, 1)

        for iBand in range(out_bands):
            band_lookup = lookup[iBand]

            dst_data = np.take(band_lookup, src_data)
            tif_ds.GetRasterBand(iBand + 1).WriteArray(dst_data, 0, iY)

        progress((iY + 1.0) / src_ds.RasterYSize)

    # ----------------------------------------------------------------------------
    # Translate intermediate file to output format if desired format is not TIFF.

    if tif_filename == dst_filename:
        dst_ds = tif_ds
    else:
        dst_ds = dst_driver.CreateCopy(dst_filename or '', tif_ds)
        tif_ds = None
        gtiff_driver.Delete(tif_filename)

    return dst_ds
예제 #6
0
def gdal_pansharpen(argv: Optional[Sequence[str]] = None,
                    pan_name: Optional[str] = None,
                    spectral_names: Optional[Sequence[str]] = None,
                    spectral_ds: Optional[List[gdal.Dataset]] = None,
                    spectral_bands: Optional[List[gdal.Band]] = None,
                    band_nums: Optional[Sequence[int]] = None,
                    weights: Optional[Sequence[float]] = None,
                    dst_filename: Optional[str] = None,
                    driver_name: Optional[str] = None,
                    creation_options: Optional[Sequence[str]] = None,
                    resampling: Optional[str] = None,
                    spat_adjust: Optional[str] = None,
                    num_threads: Optional[Union[int, str]] = None,
                    bitdepth: Optional[Union[int, str]] = None,
                    nodata_value: Optional[Union[Real, str]] = None,
                    verbose_vrt: bool = False,
                    progress_callback: Optional = gdal.TermProgress_nocb):
    """Build a VRTPansharpenedDataset and write it to `dst_filename`.

    With the VRT driver the generated XML is written to `dst_filename`
    directly. With any other driver the XML is opened in memory and copied
    to `dst_filename` with CreateCopy.

    :param argv: legacy command-line arguments; when given, everything else
        is ignored and ``main(argv)`` is returned.
    :param pan_name: panchromatic raster (its band 1 is used).
    :param spectral_names: spectral inputs, resolved into `spectral_ds` /
        `spectral_bands` by ``parse_spectral_names``.
    :param spectral_ds: dataset of each entry of `spectral_bands` (same
        order); its description is used as the source file name.
    :param spectral_bands: spectral input bands.
    :param band_nums: 1-based indices into `spectral_bands` selecting the
        output bands; defaults to all of them in order.
    :param weights: optional weights, one per spectral band.
    :param dst_filename: output file.
    :param driver_name: output driver; guessed from `dst_filename` with
        ``GetOutputDriverFor`` when None.
    :param creation_options: creation options for non-VRT output.
    :param resampling: optional <Resampling> value.
    :param spat_adjust: optional <SpatialExtentAdjustment> value.
    :param num_threads: optional <NumThreads> value.
    :param bitdepth: optional <BitDepth> value.
    :param nodata_value: optional <NoData> value.
    :param verbose_vrt: for VRT output, reopen it in update mode and write
        its metadata back.
    :param progress_callback: callback for the final CreateCopy.
    :return: 0 on success, 1 on failure.
    """
    # Local import: file names and option strings go into XML text nodes and
    # must have &, < and > escaped.
    from xml.sax.saxutils import escape

    if argv:
        # this is here for backwards compatibility
        return main(argv)

    spectral_names = spectral_names or []
    spectral_ds = spectral_ds or []
    spectral_bands = spectral_bands or []
    band_nums = band_nums or []
    weights = weights or []
    creation_options = creation_options or []

    if spectral_names:
        parse_spectral_names(spectral_names=spectral_names,
                             spectral_ds=spectral_ds,
                             spectral_bands=spectral_bands)

    if pan_name is None or not spectral_bands:
        return 1

    pan_ds = gdal.Open(pan_name)
    if pan_ds is None:
        return 1

    if driver_name is None:
        driver_name = GetOutputDriverFor(dst_filename)
    is_vrt = driver_name.upper() == 'VRT'

    if not band_nums:
        band_nums = [j + 1 for j in range(len(spectral_bands))]
    else:
        for band in band_nums:
            # Band numbers are 1-based: 0 would silently select the last band
            # through spectral_bands[-1] below.
            if band < 1 or band > len(spectral_bands):
                print('Invalid band number in -b: %d' % band)
                return 1

    if weights and len(weights) != len(spectral_bands):
        print(
            'There must be as many -w values specified as input spectral bands'
        )
        return 1

    vrt_xml = """<VRTDataset subClass="VRTPansharpenedDataset">\n"""
    # Explicit output bands are only needed when they differ from the default
    # "all spectral bands in order".
    if band_nums != [j + 1 for j in range(len(spectral_bands))]:
        for i, band in enumerate(band_nums):
            sband = spectral_bands[band - 1]
            datatype = gdal.GetDataTypeName(sband.DataType)
            colorname = gdal.GetColorInterpretationName(
                sband.GetColorInterpretation())
            vrt_xml += """  <VRTRasterBand dataType="%s" band="%d" subClass="VRTPansharpenedRasterBand">
      <ColorInterp>%s</ColorInterp>
  </VRTRasterBand>\n""" % (datatype, i + 1, colorname)

    vrt_xml += """  <PansharpeningOptions>\n"""

    if weights:
        vrt_xml += """      <AlgorithmOptions>\n"""
        vrt_xml += """        <Weights>"""
        vrt_xml += ",".join("%.16g" % weight for weight in weights)
        vrt_xml += "</Weights>\n"
        vrt_xml += """      </AlgorithmOptions>\n"""

    if resampling is not None:
        vrt_xml += f'      <Resampling>{escape(str(resampling))}</Resampling>\n'

    if num_threads is not None:
        vrt_xml += f'      <NumThreads>{escape(str(num_threads))}</NumThreads>\n'

    if bitdepth is not None:
        vrt_xml += f'      <BitDepth>{escape(str(bitdepth))}</BitDepth>\n'

    if nodata_value is not None:
        vrt_xml += f'      <NoData>{escape(str(nodata_value))}</NoData>\n'

    if spat_adjust is not None:
        vrt_xml += f'      <SpatialExtentAdjustment>{escape(str(spat_adjust))}</SpatialExtentAdjustment>\n'

    # A VRT file on disk stores relative source paths relative to itself.
    pan_relative = '0'
    if is_vrt and not os.path.isabs(pan_name):
        pan_relative = '1'
        pan_name = os.path.relpath(pan_name, os.path.dirname(dst_filename))

    vrt_xml += """    <PanchroBand>
      <SourceFilename relativeToVRT="%s">%s</SourceFilename>
      <SourceBand>1</SourceBand>
    </PanchroBand>\n""" % (pan_relative, escape(pan_name))

    for i, sband in enumerate(spectral_bands):
        # Map this spectral input to the output band that selects it, if any.
        dstband = ''
        for j, band in enumerate(band_nums):
            if i + 1 == band:
                dstband = ' dstBand="%d"' % (j + 1)
                break

        ms_relative = '0'
        ms_name = spectral_ds[i].GetDescription()
        if is_vrt and not os.path.isabs(ms_name):
            ms_relative = '1'
            ms_name = os.path.relpath(ms_name,
                                      os.path.dirname(dst_filename))

        vrt_xml += """    <SpectralBand%s>
      <SourceFilename relativeToVRT="%s">%s</SourceFilename>
      <SourceBand>%d</SourceBand>
    </SpectralBand>\n""" % (dstband, ms_relative, escape(ms_name), sband.GetBand())

    vrt_xml += """  </PansharpeningOptions>\n"""
    vrt_xml += """</VRTDataset>\n"""

    if is_vrt:
        f = gdal.VSIFOpenL(dst_filename, 'wb')
        if f is None:
            print('Cannot create %s' % dst_filename)
            return 1
        # len() of a str counts characters, not UTF-8 bytes: write the
        # encoded bytes so the byte count is right for non-ASCII paths.
        data = vrt_xml.encode('utf-8')
        gdal.VSIFWriteL(data, 1, len(data), f)
        gdal.VSIFCloseL(f)
        if verbose_vrt:
            vrt_ds = gdal.Open(dst_filename, gdal.GA_Update)
            if vrt_ds is None:
                return 1
            vrt_ds.SetMetadata(vrt_ds.GetMetadata())
        else:
            vrt_ds = gdal.Open(dst_filename)
        if vrt_ds is None:
            return 1

        return 0

    vrt_ds = gdal.Open(vrt_xml)
    if vrt_ds is None:
        return 1
    out_driver = gdal.GetDriverByName(driver_name)
    if out_driver is None:
        print('"%s" driver not registered.' % driver_name)
        return 1
    out_ds = out_driver.CreateCopy(
        dst_filename, vrt_ds, 0, creation_options, callback=progress_callback)
    if out_ds is None:
        return 1
    return 0
예제 #7
0
def _ds_basename(src_dsname):
    """Return the file name of `src_dsname` without its last extension.

    Returns None when `src_dsname` does not exist on the filesystem.
    """
    if not os.path.exists(src_dsname):
        return None
    basename = os.path.basename(src_dsname)
    if '.' in basename:
        basename = '.'.join(basename.split(".")[0:-1])
    return basename


def _expand_auto_name(layer_name, basename, src_ds_idx, src_lyr_name):
    """Substitute the {AUTO_NAME} placeholder of `layer_name`.

    {AUTO_NAME} becomes the dataset basename when it equals the layer name,
    'Dataset<idx>_<layer>' when the dataset has no basename, and
    '<basename>_<layer>' otherwise.
    """
    if basename == src_lyr_name:
        return layer_name.replace('{AUTO_NAME}', basename)
    if basename is None:
        return layer_name.replace(
            '{AUTO_NAME}', 'Dataset%d_%s' % (src_ds_idx, src_lyr_name))
    return layer_name.replace('{AUTO_NAME}', basename + '_' + src_lyr_name)


def _write_vrt_src_layer(writer, layer_name, src_dsname, src_lyr_name,
                         relative_to_vrt, shared, a_srs, s_srs, t_srs):
    """Write one OGRVRTLayer element pointing at layer `src_lyr_name` of
    `src_dsname`, wrapped in an OGRVRTWarpedLayer when `t_srs` is set."""
    if t_srs is not None:
        writer.open_element('OGRVRTWarpedLayer')

    writer.open_element('OGRVRTLayer', attrs={'name': layer_name})
    attrs = {}
    if relative_to_vrt:
        attrs['relativeToVRT'] = '1'
    if shared:
        attrs['shared'] = '1'
    writer.write_element_value('SrcDataSource', src_dsname, attrs=attrs)
    writer.write_element_value('SrcLayer', src_lyr_name)

    if a_srs is not None:
        writer.write_element_value('LayerSRS', a_srs)

    writer.close_element('OGRVRTLayer')

    if t_srs is not None:
        if s_srs is not None:
            writer.write_element_value('SrcSRS', s_srs)

        writer.write_element_value('TargetSRS', t_srs)

        writer.close_element('OGRVRTWarpedLayer')


def process(argv, progress=None, progress_arg=None):
    """Merge vector layers from several datasets into one output dataset.

    Parses command-line style arguments from `argv` (no program name: every
    argument not starting with '-' is a source dataset, with '*' globbed).
    Writes an OGR VRT description of the selected layers, either directly to
    the `-o` file when the output format is VRT, or to a temporary /vsimem/
    file that gdal.VectorTranslate then converts into the destination.

    :param argv: argument list.
    :param progress: progress callback for VectorTranslate; replaced by
        ogr.TermProgress_nocb when -progress is given.
    :param progress_arg: callback data passed along with `progress`.
    :return: 0 on success, 1 on error, or the value of Usage() for empty or
        unrecognized arguments.
    """

    if not argv:
        return Usage()

    dst_filename = None
    output_format = None
    src_datasets = []
    overwrite_ds = False
    overwrite_layer = False
    update = False
    append = False
    single_layer = False
    layer_name_template = None
    skip_failures = False
    src_geom_types = []
    field_strategy = None
    src_layer_field_name = None
    src_layer_field_content = None
    a_srs = None
    s_srs = None
    t_srs = None
    dsco = []
    lco = []

    i = 0
    while i < len(argv):
        arg = argv[i]
        if (arg == '-f' or arg == '-of') and i + 1 < len(argv):
            i = i + 1
            output_format = argv[i]
        elif arg == '-o' and i + 1 < len(argv):
            i = i + 1
            dst_filename = argv[i]
        elif arg == '-progress':
            progress = ogr.TermProgress_nocb
            progress_arg = None
        elif arg == '-q' or arg == '-quiet':
            pass
        elif arg[0:5] == '-skip':
            skip_failures = True
        elif arg == '-update':
            update = True
        elif arg == '-overwrite_ds':
            overwrite_ds = True
        elif arg == '-overwrite_layer':
            overwrite_layer = True
            update = True
        elif arg == '-append':
            append = True
            update = True
        elif arg == '-single':
            single_layer = True
        elif arg == '-a_srs' and i + 1 < len(argv):
            i = i + 1
            a_srs = argv[i]
        elif arg == '-s_srs' and i + 1 < len(argv):
            i = i + 1
            s_srs = argv[i]
        elif arg == '-t_srs' and i + 1 < len(argv):
            i = i + 1
            t_srs = argv[i]
        elif arg == '-nln' and i + 1 < len(argv):
            i = i + 1
            layer_name_template = argv[i]
        elif arg == '-field_strategy' and i + 1 < len(argv):
            i = i + 1
            field_strategy = argv[i]
        elif arg == '-src_layer_field_name' and i + 1 < len(argv):
            i = i + 1
            src_layer_field_name = argv[i]
        elif arg == '-src_layer_field_content' and i + 1 < len(argv):
            i = i + 1
            src_layer_field_content = argv[i]
        elif arg == '-dsco' and i + 1 < len(argv):
            i = i + 1
            dsco.append(argv[i])
        elif arg == '-lco' and i + 1 < len(argv):
            i = i + 1
            lco.append(argv[i])
        elif arg == '-src_geom_type' and i + 1 < len(argv):
            i = i + 1
            src_geom_type_names = argv[i].split(',')
            for src_geom_type_name in src_geom_type_names:
                src_geom_type = _GetGeomType(src_geom_type_name)
                if src_geom_type is None:
                    print('ERROR: Unrecognized geometry type: %s' %
                          src_geom_type_name)
                    return 1
                src_geom_types.append(src_geom_type)
        elif arg.startswith('-'):
            # startswith() rather than arg[0]: an empty argument must not
            # raise IndexError.
            print('ERROR: Unrecognized argument : %s' % arg)
            return Usage()
        else:
            if '*' in arg:
                src_datasets += glob.glob(arg)
            else:
                src_datasets.append(arg)
        i = i + 1

    if dst_filename is None:
        print('Missing -o')
        return 1

    if update:
        if output_format is not None:
            print('ERROR: -f incompatible with -update')
            return 1
        if dsco:
            print('ERROR: -dsco incompatible with -update')
            return 1
        output_format = ''
    else:
        if output_format is None:
            output_format = GetOutputDriverFor(dst_filename, is_raster=False)

    if src_layer_field_content is None:
        src_layer_field_content = '{AUTO_NAME}'
    elif src_layer_field_name is None:
        src_layer_field_name = 'source_ds_lyr'

    if not single_layer and output_format == 'ESRI Shapefile' and \
       dst_filename.lower().endswith('.shp'):
        print('ERROR: Non-single layer mode incompatible with non-directory '
              'shapefile output')
        return 1

    if not src_datasets:
        print('ERROR: No source datasets')
        return 1

    if layer_name_template is None:
        if single_layer:
            layer_name_template = 'merged'
        else:
            layer_name_template = '{AUTO_NAME}'

    vrt_filename = None
    if not EQUAL(output_format, 'VRT'):
        dst_ds = gdal.OpenEx(dst_filename, gdal.OF_VECTOR | gdal.OF_UPDATE)
        if dst_ds is not None:
            if not update and not overwrite_ds:
                print('ERROR: Destination dataset already exists, ' +
                      'but neither -update nor -overwrite_ds is specified')
                return 1
            if overwrite_ds:
                drv = dst_ds.GetDriver()
                dst_ds = None
                if drv.GetDescription() == 'OGR_VRT':
                    # We don't want to destroy the sources of the VRT
                    gdal.Unlink(dst_filename)
                else:
                    drv.Delete(dst_filename)
        elif update:
            print('ERROR: Destination dataset does not exist')
            return 1
        if dst_ds is None:
            drv = gdal.GetDriverByName(output_format)
            if drv is None:
                print('ERROR: Invalid driver: %s' % output_format)
                return 1
            dst_ds = drv.Create(dst_filename, 0, 0, 0, gdal.GDT_Unknown, dsco)
            if dst_ds is None:
                return 1

        # Intermediate VRT, converted into dst_ds by VectorTranslate below.
        vrt_filename = '/vsimem/_ogrmerge_.vrt'
    else:
        if gdal.VSIStatL(dst_filename) and not overwrite_ds:
            print('ERROR: Destination dataset already exists, ' +
                  'but -overwrite_ds is not specified')
            return 1
        vrt_filename = dst_filename

    f = gdal.VSIFOpenL(vrt_filename, 'wb')
    if f is None:
        print('ERROR: Cannot create %s' % vrt_filename)
        return 1

    writer = XMLWriter(f)
    writer.open_element('OGRVRTDataSource')

    # In single-layer mode every source layer goes into one OGRVRTUnionLayer,
    # opened lazily when the first matching layer is seen.
    ogr_vrt_union_layer_written = False

    for src_ds_idx, src_dsname in enumerate(src_datasets):
        src_ds = ogr.Open(src_dsname)
        if src_ds is None:
            print('ERROR: Cannot open %s' % src_dsname)
            if skip_failures:
                continue
            gdal.VSIFCloseL(f)
            gdal.Unlink(vrt_filename)
            return 1

        basename = _ds_basename(src_dsname)
        # Relative source paths can only be kept for a VRT output that sits
        # in the current directory.
        relative_to_vrt = (EQUAL(output_format, 'VRT') and
                           os.path.exists(src_dsname) and
                           not os.path.isabs(src_dsname) and
                           '/' not in vrt_filename and
                           '\\' not in vrt_filename)

        for src_lyr_idx, src_lyr in enumerate(src_ds):
            if src_geom_types:
                gt = ogr.GT_Flatten(src_lyr.GetGeomType())
                if gt not in src_geom_types:
                    continue

            if single_layer and not ogr_vrt_union_layer_written:
                ogr_vrt_union_layer_written = True
                writer.open_element('OGRVRTUnionLayer',
                                    attrs={'name': layer_name_template})

                if src_layer_field_name is not None:
                    writer.write_element_value('SourceLayerFieldName',
                                               src_layer_field_name)

                if field_strategy is not None:
                    writer.write_element_value('FieldStrategy',
                                               field_strategy)

            src_lyr_name = src_lyr.GetName()
            try:
                src_lyr_name = src_lyr_name.decode('utf-8')
            except AttributeError:
                pass

            if single_layer:
                # Name of this source within the union layer.
                layer_name = _expand_auto_name(src_layer_field_content,
                                               basename, src_ds_idx,
                                               src_lyr_name)
                layer_name = layer_name.replace(
                    '{DS_BASENAME}',
                    basename if basename is not None else src_dsname)
            else:
                layer_name = _expand_auto_name(layer_name_template,
                                               basename, src_ds_idx,
                                               src_lyr_name)
                if basename is not None:
                    layer_name = layer_name.replace('{DS_BASENAME}', basename)
                elif '{DS_BASENAME}' in layer_name:
                    if skip_failures:
                        if '{DS_INDEX}' not in layer_name:
                            layer_name = layer_name.replace(
                                '{DS_BASENAME}', 'Dataset%d' % src_ds_idx)
                    else:
                        print('ERROR: Layer name template %s '
                              'includes {DS_BASENAME} '
                              'but %s is not a file' %
                              (layer_name_template, src_dsname))

                        gdal.VSIFCloseL(f)
                        gdal.Unlink(vrt_filename)
                        return 1

            layer_name = layer_name.replace('{DS_NAME}', '%s' % src_dsname)
            layer_name = layer_name.replace('{DS_INDEX}', '%d' % src_ds_idx)
            layer_name = layer_name.replace('{LAYER_NAME}', src_lyr_name)
            layer_name = layer_name.replace('{LAYER_INDEX}',
                                            '%d' % src_lyr_idx)

            _write_vrt_src_layer(writer, layer_name, src_dsname,
                                 src_lyr_name, relative_to_vrt,
                                 single_layer, a_srs, s_srs, t_srs)

    if ogr_vrt_union_layer_written:
        writer.close_element('OGRVRTUnionLayer')

    writer.close_element('OGRVRTDataSource')

    gdal.VSIFCloseL(f)

    ret = 0
    if not EQUAL(output_format, 'VRT'):
        accessMode = None
        if append:
            accessMode = 'append'
        elif overwrite_layer:
            accessMode = 'overwrite'
        ret = gdal.VectorTranslate(dst_ds,
                                   vrt_filename,
                                   accessMode=accessMode,
                                   layerCreationOptions=lco,
                                   skipFailures=skip_failures,
                                   callback=progress,
                                   callback_data=progress_arg)
        # With a Dataset as destination, VectorTranslate returns 1 on
        # success; map that to this function's 0-on-success convention.
        ret = 0 if ret == 1 else 1
        gdal.Unlink(vrt_filename)

    return ret
예제 #8
0
def main(argv):
    """Command-line entry point of the sieve filter utility.

    Parses ``argv`` (after GDAL's generic option processing), opens the
    source raster and runs ``gdal.SieveFilter`` on its first band.  When no
    destination file is given, the source is opened in update mode and
    filtered in place; otherwise a single-band output raster is created with
    the source's size, band data type, projection and geotransform.

    :param argv: full argument vector; ``argv[0]`` is the program name.
    :returns: 0 if GDAL's generic processor consumed the command line,
        ``Usage()``'s value on invalid usage, 1 on an error reported here,
        otherwise the value returned by ``gdal.SieveFilter``.
    """
    threshold = 2
    connectedness = 4
    quiet_flag = 0
    src_filename = None

    dst_filename = None
    frmt = None

    mask = 'default'
    mask_ds = None

    argv = gdal.GeneralCmdLineProcessor(argv)
    if argv is None:
        return 0

    # Parse command line arguments.
    i = 1
    while i < len(argv):
        arg = argv[i]

        # Options that consume the following argument need one to exist.
        if arg in ('-of', '-f', '-st', '-mask') and i + 1 >= len(argv):
            return Usage()

        if arg == '-of' or arg == '-f':
            i = i + 1
            frmt = argv[i]

        elif arg == '-4':
            connectedness = 4

        elif arg == '-8':
            connectedness = 8

        elif arg == '-q' or arg == '-quiet':
            quiet_flag = 1

        elif arg == '-st':
            i = i + 1
            threshold = int(argv[i])

        elif arg == '-nomask':
            mask = 'none'

        elif arg == '-mask':
            i = i + 1
            mask = argv[i]

        elif arg[:2] == '-h':
            return Usage()

        elif src_filename is None:
            src_filename = argv[i]

        elif dst_filename is None:
            dst_filename = argv[i]

        else:
            return Usage()

        i = i + 1

    if src_filename is None:
        return Usage()

    # Verify we have next gen bindings with the SieveFilter method.
    try:
        gdal.SieveFilter
    except AttributeError:
        print('')
        print(
            'gdal.SieveFilter() not available.  You are likely using "old gen"'
        )
        print('bindings or an older version of the next gen bindings.')
        print('')
        return 1

    # Open the source file; in update mode when it is filtered in place.
    if dst_filename is None:
        src_ds = gdal.Open(src_filename, gdal.GA_Update)
    else:
        src_ds = gdal.Open(src_filename, gdal.GA_ReadOnly)

    if src_ds is None:
        print('Unable to open %s ' % src_filename)
        return 1

    srcband = src_ds.GetRasterBand(1)

    # 'default': the band's own mask, 'none': no mask, anything else is the
    # name of a raster whose first band is used as the mask.
    if mask == 'default':
        maskband = srcband.GetMaskBand()
    elif mask == 'none':
        maskband = None
    else:
        mask_ds = gdal.Open(mask)
        if mask_ds is None:
            print('Unable to open %s ' % mask)
            return 1
        maskband = mask_ds.GetRasterBand(1)

    # Create output file if one is specified.
    if dst_filename is not None:
        if frmt is None:
            frmt = GetOutputDriverFor(dst_filename)

        drv = gdal.GetDriverByName(frmt)
        if drv is None:
            print('Unknown output format: %s' % frmt)
            return 1
        dst_ds = drv.Create(dst_filename, src_ds.RasterXSize,
                            src_ds.RasterYSize, 1, srcband.DataType)
        if dst_ds is None:
            print('Unable to create %s' % dst_filename)
            return 1
        wkt = src_ds.GetProjection()
        if wkt != '':
            dst_ds.SetProjection(wkt)
        gt = src_ds.GetGeoTransform(can_return_null=True)
        if gt is not None:
            dst_ds.SetGeoTransform(gt)

        dstband = dst_ds.GetRasterBand(1)
    else:
        dst_ds = None
        dstband = srcband

    # Invoke algorithm.
    if quiet_flag:
        prog_func = None
    else:
        prog_func = gdal.TermProgress_nocb

    result = gdal.SieveFilter(srcband,
                              maskband,
                              dstband,
                              threshold,
                              connectedness,
                              callback=prog_func)

    # Release references to the datasets.
    src_ds = None
    dst_ds = None
    mask_ds = None

    return result
예제 #9
0
def main(argv=sys.argv):
    """Write the block (tile) layout of a raster as polygon layers.

    For band 1 at full resolution and each of its overviews (or only the
    level selected with ``-ovr``), one polygon layer is created in the
    output vector dataset.  Each layer holds one rectangle per raster
    block.  Level 0 is the full-resolution band; level N is overview N-1.

    Usage: [-f format] [-ovr level] in_raster out_vector

    :param argv: full argument vector; ``argv[0]`` is the program name.
    :returns: 0 on success, 1 on an error reported here, ``Usage()``'s
        value on invalid usage.
    """
    i = 1
    output_format = None
    in_filename = None
    out_filename = None
    ovr_level = None
    while i < len(argv):
        arg = argv[i]
        # Options that consume the following argument need one to exist.
        if arg in ("-f", "-ovr") and i + 1 >= len(argv):
            return Usage()
        if arg == "-f":
            output_format = argv[i + 1]
            i = i + 1
        elif arg == "-ovr":
            ovr_level = int(argv[i + 1])
            i = i + 1
        elif arg.startswith('-'):
            return Usage()
        elif in_filename is None:
            in_filename = arg
        elif out_filename is None:
            out_filename = arg
        else:
            return Usage()

        i = i + 1

    if out_filename is None:
        return Usage()
    if output_format is None:
        output_format = GetOutputDriverFor(out_filename, is_raster=False)

    src_ds = gdal.Open(in_filename)
    if src_ds is None:
        print('Unable to open %s' % in_filename)
        return 1
    first_band = src_ds.GetRasterBand(1)
    if first_band is None:
        print('%s has no raster band' % in_filename)
        return 1

    overview_count = first_band.GetOverviewCount()
    if ovr_level is None:
        levels = range(1 + overview_count)
    elif 0 <= ovr_level <= overview_count:
        levels = [ovr_level]
    else:
        print('Invalid overview level %d: must be between 0 and %d' %
              (ovr_level, overview_count))
        return 1

    driver = gdal.GetDriverByName(output_format)
    if driver is None:
        print('Unknown output format: %s' % output_format)
        return 1
    out_ds = driver.Create(out_filename, 0, 0, 0, gdal.GDT_Unknown)
    if out_ds is None:
        print('Unable to create %s' % out_filename)
        return 1

    main_gt = src_ds.GetGeoTransform()

    for level in levels:
        src_band = first_band if level == 0 else first_band.GetOverview(level - 1)
        out_lyr = out_ds.CreateLayer('main_image' if level == 0 else
                                     ('overview_%d' % level),
                                     geom_type=ogr.wkbPolygon,
                                     srs=src_ds.GetSpatialRef())
        blockxsize, blockysize = src_band.GetBlockSize()
        # Ceiling division: partial blocks at the right/bottom edges count.
        nxblocks = (src_band.XSize + blockxsize - 1) // blockxsize
        nyblocks = (src_band.YSize + blockysize - 1) // blockysize
        # Geotransform of this level: same origin, pixel size scaled by the
        # ratio of full-resolution to overview dimensions (no rotation).
        gt = [
            main_gt[0], main_gt[1] * first_band.XSize / src_band.XSize, 0,
            main_gt[3], 0, main_gt[5] * first_band.YSize / src_band.YSize
        ]
        for y in range(nyblocks):
            ymax = gt[3] + y * blockysize * gt[5]
            ymin = ymax + blockysize * gt[5]
            for x in range(nxblocks):
                xmin = gt[0] + x * blockxsize * gt[1]
                xmax = xmin + blockxsize * gt[1]
                feature = ogr.Feature(out_lyr.GetLayerDefn())
                wkt = 'POLYGON((%.18g %.18g,%.18g %.18g,%.18g %.18g,%.18g %.18g,%.18g %.18g))' % (
                    xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin, xmin, ymin)
                feature.SetGeometryDirectly(ogr.CreateGeometryFromWkt(wkt))
                out_lyr.CreateFeature(feature)
    out_ds = None
    return 0
예제 #10
0
def gdal_polygonize(src_filename: Optional[str] = None,
                    band_number: Union[int, str] = 1,
                    dst_filename: Optional[str] = None,
                    driver_name: Optional[str] = None,
                    dst_layername: Optional[str] = None,
                    dst_fieldname: Optional[str] = None,
                    quiet: bool = False,
                    mask: str = 'default',
                    options: Optional[list] = None,
                    connectedness8: bool = False):
    """Polygonize a raster band into a polygon layer of a vector dataset.

    :param src_filename: source raster.
    :param band_number: band index (int or numeric string), ``'mask'`` for
        the mask of band 1, or ``'mask,N'`` for the mask of band N.
    :param dst_filename: destination vector dataset; reused if it can be
        opened in update mode, created with ``driver_name`` otherwise.
    :param driver_name: OGR driver; guessed from ``dst_filename`` if None.
    :param dst_layername: destination layer name (default ``'out'``);
        reused if present, created otherwise.
    :param dst_fieldname: attribute field receiving the pixel value.  For a
        new layer it defaults to ``'DN'``; for an existing layer without it
        no value is written.
    :param quiet: suppress progress output and informational messages.
    :param mask: ``'default'`` (the band's mask), ``'none'``, or the name of
        a raster whose band 1 is used as the mask.
    :param options: extra ``gdal.Polygonize`` options.  The caller's list is
        not modified.
    :param connectedness8: use 8-connectedness instead of 4.
    :returns: 1 on an error reported here, otherwise the value returned by
        ``gdal.Polygonize``.
    """
    if isinstance(band_number, str) and not band_number.startswith('mask'):
        band_number = int(band_number)

    # Copy so the options appended below never leak into the caller's list.
    options = list(options) if options else []

    if connectedness8:
        options.append('8CONNECTED=8')

    if driver_name is None:
        driver_name = GetOutputDriverFor(dst_filename, is_raster=False)

    if dst_layername is None:
        dst_layername = 'out'

    # Open source file.
    src_ds = gdal.Open(src_filename)

    if src_ds is None:
        print('Unable to open %s' % src_filename)
        return 1

    if isinstance(band_number, str):
        # 'mask' selects the mask of band 1, 'mask,N' the mask of band N.
        if band_number == 'mask':
            base_band = src_ds.GetRasterBand(1)
        elif band_number.startswith('mask,'):
            base_band = src_ds.GetRasterBand(int(band_number[len('mask,'):]))
        else:
            print('Invalid band specification: %s' % band_number)
            return 1
        srcband = base_band.GetMaskBand() if base_band is not None else None
        # Workaround the fact that most mask bands have no dataset attached
        options.append('DATASET_FOR_GEOREF=' + src_filename)
    else:
        srcband = src_ds.GetRasterBand(band_number)

    if srcband is None:
        print('Unable to get band %s of %s' % (band_number, src_filename))
        return 1

    mask_ds = None
    if mask == 'default':
        maskband = srcband.GetMaskBand()
    elif mask == 'none':
        maskband = None
    else:
        mask_ds = gdal.Open(mask)
        if mask_ds is None:
            print('Unable to open %s' % mask)
            return 1
        maskband = mask_ds.GetRasterBand(1)

    # Try opening the destination file as an existing file.  The quiet
    # handler is always popped, even if ogr.Open raises.
    gdal.PushErrorHandler('CPLQuietErrorHandler')
    try:
        dst_ds = ogr.Open(dst_filename, update=1)
    except Exception:
        # Raised instead of returning None when GDAL exceptions are enabled.
        dst_ds = None
    finally:
        gdal.PopErrorHandler()

    # Create output file.
    if dst_ds is None:
        drv = ogr.GetDriverByName(driver_name)
        if drv is None:
            print('Unknown output format: %s' % driver_name)
            return 1
        if not quiet:
            print('Creating output %s of format %s.' %
                  (dst_filename, driver_name))
        dst_ds = drv.CreateDataSource(dst_filename)
        if dst_ds is None:
            print('Unable to create %s' % dst_filename)
            return 1

    # Find or create destination layer.
    try:
        dst_layer = dst_ds.GetLayerByName(dst_layername)
    except Exception:
        dst_layer = None

    dst_field: int = -1
    if dst_layer is None:

        srs = src_ds.GetSpatialRef()
        dst_layer = dst_ds.CreateLayer(dst_layername,
                                       geom_type=ogr.wkbPolygon,
                                       srs=srs)
        if dst_layer is None:
            print('Unable to create layer %s' % dst_layername)
            return 1

        if dst_fieldname is None:
            dst_fieldname = 'DN'

        # 64-bit integer bands need a 64-bit attribute field.
        data_type = ogr.OFTInteger
        if srcband.DataType == gdal.GDT_Int64 or srcband.DataType == gdal.GDT_UInt64:
            data_type = ogr.OFTInteger64

        fd = ogr.FieldDefn(dst_fieldname, data_type)
        dst_layer.CreateField(fd)
        dst_field = 0
    else:
        if dst_fieldname is not None:
            dst_field = dst_layer.GetLayerDefn().GetFieldIndex(dst_fieldname)
            if dst_field < 0:
                print("Warning: cannot find field '%s' in layer '%s'" %
                      (dst_fieldname, dst_layername))

    # Invoke algorithm.
    if quiet:
        prog_func = None
    else:
        prog_func = gdal.TermProgress_nocb

    result = gdal.Polygonize(srcband,
                             maskband,
                             dst_layer,
                             dst_field,
                             options,
                             callback=prog_func)

    # Release references to bands and datasets.
    srcband = None
    src_ds = None
    dst_ds = None
    mask_ds = None

    return result
예제 #11
0
파일: rgb2pct.py 프로젝트: whatcoloris/gdal
def rgb2pct(src_filename: PathLikeOrStr,
            pct_filename: Optional[PathLikeOrStr] = None,
            dst_filename: Optional[PathLikeOrStr] = None,
            color_count: int = 256,
            driver_name: Optional[str] = None):
    """Convert an RGB raster into a single-band paletted raster.

    The palette is either computed from bands 1-3 of the source with the
    median-cut algorithm (``color_count`` entries) or read from
    ``pct_filename``.  The RGB data is then dithered into that palette.

    The dithering target is always a GTiff dataset, because few formats
    allow setting the color table after creation.  For any other output
    driver, a temporary GTiff is written, copied with ``CreateCopy``, and
    deleted.  The temporary file is also removed if a step fails.

    :returns: the destination dataset.
    :raises Exception: if the source cannot be opened or has fewer than
        3 bands, the driver is unknown, no color table can be obtained,
        the working file cannot be created, or a GDAL algorithm fails.
    """
    import tempfile

    # Open source file
    src_ds = open_ds(src_filename)
    if src_ds is None:
        raise Exception(f'Unable to open {src_filename}')

    if src_ds.RasterCount < 3:
        raise Exception(
            f'{src_filename} has {src_ds.RasterCount} band(s), need 3 for inputs red, green and blue.'
        )

    # Ensure we recognise the driver.
    if not driver_name:
        driver_name = GetOutputDriverFor(dst_filename)

    dst_driver = gdal.GetDriverByName(driver_name)
    if dst_driver is None:
        raise Exception(f'"{driver_name}" driver not registered.')

    # Generate palette
    if pct_filename is None:
        ct = gdal.ColorTable()
        err = gdal.ComputeMedianCutPCT(src_ds.GetRasterBand(1),
                                       src_ds.GetRasterBand(2),
                                       src_ds.GetRasterBand(3),
                                       color_count,
                                       ct,
                                       callback=gdal.TermProgress_nocb)
        if err != gdal.CE_None:
            raise Exception('ComputeMedianCutPCT failed')
    else:
        ct = get_color_table(pct_filename)
        if ct is None:
            raise Exception(f'No color table on file {pct_filename}')

    # Create the working file.  We have to use TIFF since there are few formats
    # that allow setting the color table after creation.
    use_temp_file = driver_name.lower() != 'gtiff'
    if use_temp_file:
        tif_filedesc, tif_filename = tempfile.mkstemp(suffix='.tif')
        # GDAL opens the file by name; the descriptor from mkstemp is unused.
        os.close(tif_filedesc)
    else:
        tif_filename = dst_filename

    gtiff_driver = gdal.GetDriverByName('GTiff')
    tif_ds = None
    try:
        tif_ds = gtiff_driver.Create(tif_filename, src_ds.RasterXSize,
                                     src_ds.RasterYSize, 1)
        if tif_ds is None:
            raise Exception(f'Unable to create {tif_filename}')

        tif_ds.GetRasterBand(1).SetRasterColorTable(ct)

        # Copy projection and georeferencing information.
        tif_ds.SetProjection(src_ds.GetProjection())
        tif_ds.SetGeoTransform(src_ds.GetGeoTransform())
        if src_ds.GetGCPCount() > 0:
            tif_ds.SetGCPs(src_ds.GetGCPs(), src_ds.GetGCPProjection())

        # Actually transfer and dither the data.
        err = gdal.DitherRGB2PCT(src_ds.GetRasterBand(1),
                                 src_ds.GetRasterBand(2),
                                 src_ds.GetRasterBand(3),
                                 tif_ds.GetRasterBand(1),
                                 ct,
                                 callback=gdal.TermProgress_nocb)
        if err != gdal.CE_None:
            raise Exception('DitherRGB2PCT failed')
    except Exception:
        # Close the working dataset and do not leave a temporary file behind.
        tif_ds = None
        if use_temp_file and os.path.exists(tif_filename):
            os.remove(tif_filename)
        raise

    if not use_temp_file:
        return tif_ds

    dst_ds = dst_driver.CreateCopy(dst_filename or '', tif_ds)
    tif_ds = None  # close the working file before deleting it
    gtiff_driver.Delete(tif_filename)

    return dst_ds
예제 #12
0
def main(argv):
    """Command-line entry point of the polygonize utility.

    Usage (positional part): src_raster dst_vector [dst_layer [dst_field]]

    Polygonizes one band of the source raster into a polygon layer of the
    destination vector dataset.  With ``-b mask`` or ``-b mask,N``, the mask
    of band 1 or band N is polygonized instead.  The destination dataset
    and layer are reused when they already exist, otherwise they are
    created.

    :param argv: full argument vector; ``argv[0]`` is the program name.
    :returns: 0 if GDAL's generic processor consumed the command line,
        ``Usage()``'s value on invalid usage, 1 on an error reported here,
        otherwise the value returned by ``gdal.Polygonize``.
    """
    frmt = None
    options = []
    quiet_flag = 0
    src_filename = None
    src_band_n = 1

    dst_filename = None
    dst_layername = None
    dst_fieldname = None
    dst_field = -1

    mask = 'default'
    mask_ds = None

    argv = gdal.GeneralCmdLineProcessor(argv)
    if argv is None:
        return 0

    # Parse command line arguments.
    i = 1
    while i < len(argv):
        arg = argv[i]

        # Options that consume the following argument need one to exist.
        if arg in ('-f', '-of', '-mask', '-b') and i + 1 >= len(argv):
            return Usage()

        if arg == '-f' or arg == '-of':
            i = i + 1
            frmt = argv[i]

        elif arg == '-q' or arg == '-quiet':
            quiet_flag = 1

        elif arg == '-8':
            options.append('8CONNECTED=8')

        elif arg == '-nomask':
            mask = 'none'

        elif arg == '-mask':
            i = i + 1
            mask = argv[i]

        elif arg == '-b':
            i = i + 1
            if argv[i].startswith('mask'):
                src_band_n = argv[i]
            else:
                src_band_n = int(argv[i])

        elif src_filename is None:
            src_filename = argv[i]

        elif dst_filename is None:
            dst_filename = argv[i]

        elif dst_layername is None:
            dst_layername = argv[i]

        elif dst_fieldname is None:
            dst_fieldname = argv[i]

        else:
            return Usage()

        i = i + 1

    if src_filename is None or dst_filename is None:
        return Usage()

    if frmt is None:
        frmt = GetOutputDriverFor(dst_filename, is_raster=False)

    if dst_layername is None:
        dst_layername = 'out'

    # Verify we have next gen bindings with the polygonize method.
    try:
        gdal.Polygonize
    except AttributeError:
        print('')
        print(
            'gdal.Polygonize() not available.  You are likely using "old gen"')
        print('bindings or an older version of the next gen bindings.')
        print('')
        return 1

    # Open source file.
    src_ds = gdal.Open(src_filename)

    if src_ds is None:
        print('Unable to open %s' % src_filename)
        return 1

    if isinstance(src_band_n, str):
        # '-b mask' selects the mask of band 1, '-b mask,N' the mask of band N.
        if src_band_n == 'mask':
            base_band = src_ds.GetRasterBand(1)
        elif src_band_n.startswith('mask,'):
            base_band = src_ds.GetRasterBand(int(src_band_n[len('mask,'):]))
        else:
            return Usage()
        srcband = base_band.GetMaskBand() if base_band is not None else None
        # Workaround the fact that most mask bands have no dataset attached
        options.append('DATASET_FOR_GEOREF=' + src_filename)
    else:
        srcband = src_ds.GetRasterBand(src_band_n)

    if srcband is None:
        print('Unable to get band %s of %s' % (src_band_n, src_filename))
        return 1

    if mask == 'default':
        maskband = srcband.GetMaskBand()
    elif mask == 'none':
        maskband = None
    else:
        mask_ds = gdal.Open(mask)
        if mask_ds is None:
            print('Unable to open %s' % mask)
            return 1
        maskband = mask_ds.GetRasterBand(1)

    # Try opening the destination file as an existing file.  The quiet
    # handler is always popped, even if ogr.Open raises.
    gdal.PushErrorHandler('CPLQuietErrorHandler')
    try:
        dst_ds = ogr.Open(dst_filename, update=1)
    except Exception:
        # Raised instead of returning None when GDAL exceptions are enabled.
        dst_ds = None
    finally:
        gdal.PopErrorHandler()

    # Create output file.
    if dst_ds is None:
        drv = ogr.GetDriverByName(frmt)
        if drv is None:
            print('Unknown output format: %s' % frmt)
            return 1
        if not quiet_flag:
            print('Creating output %s of format %s.' % (dst_filename, frmt))
        dst_ds = drv.CreateDataSource(dst_filename)
        if dst_ds is None:
            print('Unable to create %s' % dst_filename)
            return 1

    # Find or create destination layer.
    try:
        dst_layer = dst_ds.GetLayerByName(dst_layername)
    except Exception:
        dst_layer = None

    if dst_layer is None:

        srs = src_ds.GetSpatialRef()
        dst_layer = dst_ds.CreateLayer(dst_layername,
                                       geom_type=ogr.wkbPolygon,
                                       srs=srs)
        if dst_layer is None:
            print('Unable to create layer %s' % dst_layername)
            return 1

        if dst_fieldname is None:
            dst_fieldname = 'DN'

        # 64-bit integer bands need a 64-bit attribute field.  The GDT_*64
        # constants may be missing from older bindings, hence getattr.
        int64_types = [getattr(gdal, name)
                       for name in ('GDT_Int64', 'GDT_UInt64')
                       if hasattr(gdal, name)]
        if srcband.DataType in int64_types:
            data_type = ogr.OFTInteger64
        else:
            data_type = ogr.OFTInteger

        fd = ogr.FieldDefn(dst_fieldname, data_type)
        dst_layer.CreateField(fd)
        dst_field = 0
    else:
        if dst_fieldname is not None:
            dst_field = dst_layer.GetLayerDefn().GetFieldIndex(dst_fieldname)
            if dst_field < 0:
                print("Warning: cannot find field '%s' in layer '%s'" %
                      (dst_fieldname, dst_layername))

    # Invoke algorithm.
    if quiet_flag:
        prog_func = None
    else:
        prog_func = gdal.TermProgress_nocb

    result = gdal.Polygonize(srcband,
                             maskband,
                             dst_layer,
                             dst_field,
                             options,
                             callback=prog_func)

    # Release references to bands and datasets.
    srcband = None
    src_ds = None
    dst_ds = None
    mask_ds = None

    return result
def gdal_pansharpen(argv):
    """Command-line entry point of the pansharpening utility.

    Usage (positional part):
        pan_dataset spectral_dataset[,band=num]... out_dataset

    Builds a ``VRTPansharpenedDataset`` XML description from the
    panchromatic band (band 1 of ``pan_dataset``) and the spectral inputs.
    A spectral input without ``,band=num`` contributes all of its bands.
    With output format VRT, the XML is written to the output file.
    Otherwise the in-memory VRT is materialized with ``CreateCopy`` in the
    requested format.

    :param argv: full argument vector; ``argv[0]`` is the program name.
    :returns: -1 if GDAL's generic processor consumed the command line,
        ``Usage()``'s value on invalid usage, 1 on error, 0 on success.
    """
    # Local import: used to XML-escape file names written into the VRT.
    from xml.sax.saxutils import escape

    argv = gdal.GeneralCmdLineProcessor(argv)
    if argv is None:
        return -1

    pan_name = None
    last_name = None
    spectral_ds = []
    spectral_bands = []
    out_name = None
    bands = []
    weights = []
    frmt = None
    creation_options = []
    callback = gdal.TermProgress_nocb
    resampling = None
    spat_adjust = None
    verbose_vrt = False
    num_threads = None
    bitdepth = None
    nodata = None

    i = 1
    argc = len(argv)
    while i < argc:
        if (argv[i] == '-of' or argv[i] == '-f') and i < len(argv) - 1:
            frmt = argv[i + 1]
            i = i + 1
        elif argv[i] == '-r' and i < len(argv) - 1:
            resampling = argv[i + 1]
            i = i + 1
        elif argv[i] == '-spat_adjust' and i < len(argv) - 1:
            spat_adjust = argv[i + 1]
            i = i + 1
        elif argv[i] == '-b' and i < len(argv) - 1:
            bands.append(int(argv[i + 1]))
            i = i + 1
        elif argv[i] == '-w' and i < len(argv) - 1:
            weights.append(float(argv[i + 1]))
            i = i + 1
        elif argv[i] == '-co' and i < len(argv) - 1:
            creation_options.append(argv[i + 1])
            i = i + 1
        elif argv[i] == '-threads' and i < len(argv) - 1:
            num_threads = argv[i + 1]
            i = i + 1
        elif argv[i] == '-bitdepth' and i < len(argv) - 1:
            bitdepth = argv[i + 1]
            i = i + 1
        elif argv[i] == '-nodata' and i < len(argv) - 1:
            nodata = argv[i + 1]
            i = i + 1
        elif argv[i] == '-q':
            callback = None
        elif argv[i] == '-verbose_vrt':
            verbose_vrt = True
        elif argv[i].startswith('-'):
            sys.stderr.write('Unrecognized option : %s\n' % argv[i])
            return Usage()
        elif pan_name is None:
            pan_name = argv[i]
            pan_ds = gdal.Open(pan_name)
            if pan_ds is None:
                return 1
        else:
            # Every positional argument after the pan dataset is a spectral
            # input, except the last one, which is the output.  So the
            # previous name is loaded only once a newer one is seen.
            if last_name is not None:
                pos = last_name.find(',band=')
                if pos > 0:
                    spectral_name = last_name[0:pos]
                    ds = gdal.Open(spectral_name)
                    if ds is None:
                        return 1
                    band_num = int(last_name[pos + len(',band='):])
                    band = ds.GetRasterBand(band_num)
                    if band is None:
                        print('Invalid band number %d for %s' %
                              (band_num, spectral_name))
                        return 1
                    spectral_ds.append(ds)
                    spectral_bands.append(band)
                else:
                    spectral_name = last_name
                    ds = gdal.Open(spectral_name)
                    if ds is None:
                        return 1
                    for j in range(ds.RasterCount):
                        spectral_ds.append(ds)
                        spectral_bands.append(ds.GetRasterBand(j + 1))

            last_name = argv[i]

        i = i + 1

    if pan_name is None or not spectral_bands:
        return Usage()
    out_name = last_name

    if frmt is None:
        frmt = GetOutputDriverFor(out_name)

    if not bands:
        bands = [j + 1 for j in range(len(spectral_bands))]
    else:
        for band in bands:
            # Band numbers are 1-based; 0 would silently select the last
            # spectral band through Python negative indexing.
            if band < 1 or band > len(spectral_bands):
                print('Invalid band number in -b: %d' % band)
                return 1

    if weights and len(weights) != len(spectral_bands):
        print(
            'There must be as many -w values specified as input spectral bands'
        )
        return 1

    vrt_xml = """<VRTDataset subClass="VRTPansharpenedDataset">\n"""
    # Explicit output bands are only needed when -b reorders or subsets
    # the spectral bands.
    if bands != [j + 1 for j in range(len(spectral_bands))]:
        for i, band in enumerate(bands):
            sband = spectral_bands[band - 1]
            datatype = gdal.GetDataTypeName(sband.DataType)
            colorname = gdal.GetColorInterpretationName(
                sband.GetColorInterpretation())
            vrt_xml += """  <VRTRasterBand dataType="%s" band="%d" subClass="VRTPansharpenedRasterBand">
      <ColorInterp>%s</ColorInterp>
  </VRTRasterBand>\n""" % (datatype, i + 1, colorname)

    vrt_xml += """  <PansharpeningOptions>\n"""

    if weights:
        vrt_xml += """      <AlgorithmOptions>\n"""
        vrt_xml += """        <Weights>"""
        vrt_xml += ",".join("%.16g" % weight for weight in weights)
        vrt_xml += "</Weights>\n"
        vrt_xml += """      </AlgorithmOptions>\n"""

    if resampling is not None:
        vrt_xml += '      <Resampling>%s</Resampling>\n' % resampling

    if num_threads is not None:
        vrt_xml += '      <NumThreads>%s</NumThreads>\n' % num_threads

    if bitdepth is not None:
        vrt_xml += '      <BitDepth>%s</BitDepth>\n' % bitdepth

    if nodata is not None:
        vrt_xml += '      <NoData>%s</NoData>\n' % nodata

    if spat_adjust is not None:
        vrt_xml += '      <SpatialExtentAdjustment>%s</SpatialExtentAdjustment>\n' % spat_adjust

    # In a VRT file, relative source paths are stored relative to the VRT.
    pan_relative = '0'
    if frmt.upper() == 'VRT':
        if not os.path.isabs(pan_name):
            pan_relative = '1'
            pan_name = os.path.relpath(pan_name, os.path.dirname(out_name))

    vrt_xml += """    <PanchroBand>
      <SourceFilename relativeToVRT="%s">%s</SourceFilename>
      <SourceBand>1</SourceBand>
    </PanchroBand>\n""" % (pan_relative, escape(pan_name))

    for i, sband in enumerate(spectral_bands):
        dstband = ''
        for j, band in enumerate(bands):
            if i + 1 == band:
                dstband = ' dstBand="%d"' % (j + 1)
                break

        ms_relative = '0'
        ms_name = spectral_ds[i].GetDescription()
        if frmt.upper() == 'VRT':
            if not os.path.isabs(ms_name):
                ms_relative = '1'
                ms_name = os.path.relpath(ms_name, os.path.dirname(out_name))

        vrt_xml += """    <SpectralBand%s>
      <SourceFilename relativeToVRT="%s">%s</SourceFilename>
      <SourceBand>%d</SourceBand>
    </SpectralBand>\n""" % (dstband, ms_relative, escape(ms_name), sband.GetBand())

    vrt_xml += """  </PansharpeningOptions>\n"""
    vrt_xml += """</VRTDataset>\n"""

    if frmt.upper() == 'VRT':
        f = gdal.VSIFOpenL(out_name, 'wb')
        if f is None:
            print('Cannot create %s' % out_name)
            return 1
        gdal.VSIFWriteL(vrt_xml, 1, len(vrt_xml), f)
        gdal.VSIFCloseL(f)
        if verbose_vrt:
            vrt_ds = gdal.Open(out_name, gdal.GA_Update)
            if vrt_ds is None:
                return 1
            # NOTE(review): presumably marks the dataset dirty so that GDAL
            # rewrites the VRT with its fully expanded description on close.
            vrt_ds.SetMetadata(vrt_ds.GetMetadata())
        else:
            vrt_ds = gdal.Open(out_name)
        if vrt_ds is None:
            return 1

        return 0

    vrt_ds = gdal.Open(vrt_xml)
    if vrt_ds is None:
        return 1
    out_driver = gdal.GetDriverByName(frmt)
    if out_driver is None:
        print('Unknown output format: %s' % frmt)
        return 1
    out_ds = out_driver.CreateCopy(out_name,
                                   vrt_ds,
                                   0,
                                   creation_options,
                                   callback=callback)
    if out_ds is None:
        return 1
    return 0
def main(argv):
    """Command-line entry point of the proximity (distance) utility.

    Usage (positional part): src_raster dst_raster

    Runs ``gdal.ComputeProximity`` on a band of the source raster.  If the
    destination can be identified and opened in update mode, band
    ``-dstband`` of it is written.  Otherwise a new single-band raster of
    type ``-ot`` (default Float32) is created with the source's size,
    geotransform and projection.

    :param argv: full argument vector; ``argv[0]`` is the program name.
    :returns: 0 if GDAL's generic processor consumed the command line,
        ``Usage()``'s value on invalid usage, 1 on an error reported here,
        otherwise the value returned by ``gdal.ComputeProximity``.
    """
    frmt = None
    creation_options = []
    options = []
    src_filename = None
    src_band_n = 1
    dst_filename = None
    dst_band_n = 1
    creation_type = 'Float32'
    quiet_flag = 0

    # Command-line flags forwarded as NAME=value gdal.ComputeProximity options.
    algorithm_flags = {
        '-maxdist': 'MAXDIST',
        '-values': 'VALUES',
        '-distunits': 'DISTUNITS',
        '-nodata': 'NODATA',
        '-use_input_nodata': 'USE_INPUT_NODATA',
        '-fixed-buf-val': 'FIXED_BUF_VAL',
    }
    # All flags that consume the following argument.
    value_flags = {'-of', '-f', '-co', '-ot', '-srcband', '-dstband'}
    value_flags.update(algorithm_flags)

    argv = gdal.GeneralCmdLineProcessor(argv)
    if argv is None:
        return 0

    # Parse command line arguments.
    i = 1
    while i < len(argv):
        arg = argv[i]

        if arg in value_flags and i + 1 >= len(argv):
            return Usage()

        if arg == '-of' or arg == '-f':
            i = i + 1
            frmt = argv[i]

        elif arg == '-co':
            i = i + 1
            creation_options.append(argv[i])

        elif arg == '-ot':
            i = i + 1
            creation_type = argv[i]

        elif arg in algorithm_flags:
            i = i + 1
            options.append(algorithm_flags[arg] + '=' + argv[i])

        elif arg == '-srcband':
            i = i + 1
            src_band_n = int(argv[i])

        elif arg == '-dstband':
            i = i + 1
            dst_band_n = int(argv[i])

        elif arg == '-q' or arg == '-quiet':
            quiet_flag = 1

        elif src_filename is None:
            src_filename = argv[i]

        elif dst_filename is None:
            dst_filename = argv[i]

        else:
            return Usage()

        i = i + 1

    if src_filename is None or dst_filename is None:
        return Usage()

    # Open source file.
    src_ds = gdal.Open(src_filename)

    if src_ds is None:
        print('Unable to open %s' % src_filename)
        return 1

    srcband = src_ds.GetRasterBand(src_band_n)
    if srcband is None:
        print('Unable to get band %d of %s' % (src_band_n, src_filename))
        return 1

    # Try opening the destination file as an existing file.
    dst_ds = None
    dstband = None
    try:
        if gdal.IdentifyDriver(dst_filename) is not None:
            dst_ds = gdal.Open(dst_filename, gdal.GA_Update)
    except RuntimeError:
        # Raised instead of returning None when GDAL exceptions are enabled.
        dst_ds = None

    if dst_ds is not None:
        dstband = dst_ds.GetRasterBand(dst_band_n)
        if dstband is None:
            # Report instead of silently recreating (overwriting) the file.
            print('Unable to get band %d of %s' % (dst_band_n, dst_filename))
            return 1
    else:
        # Create output file.
        if frmt is None:
            frmt = GetOutputDriverFor(dst_filename)

        drv = gdal.GetDriverByName(frmt)
        if drv is None:
            print('Unknown output format: %s' % frmt)
            return 1
        data_type = gdal.GetDataTypeByName(creation_type)
        if data_type == gdal.GDT_Unknown:
            print('Unknown output data type: %s' % creation_type)
            return 1
        dst_ds = drv.Create(dst_filename, src_ds.RasterXSize,
                            src_ds.RasterYSize, 1,
                            data_type,
                            creation_options)
        if dst_ds is None:
            print('Unable to create %s' % dst_filename)
            return 1

        dst_ds.SetGeoTransform(src_ds.GetGeoTransform())
        dst_ds.SetProjection(src_ds.GetProjectionRef())

        dstband = dst_ds.GetRasterBand(1)

    # Invoke algorithm.
    if quiet_flag:
        prog_func = None
    else:
        prog_func = gdal.TermProgress_nocb

    result = gdal.ComputeProximity(srcband, dstband, options,
                                   callback=prog_func)

    # Release references to bands and datasets.
    srcband = None
    dstband = None
    src_ds = None
    dst_ds = None

    return result
예제 #15
0
def Calc(calc: MaybeSequence[str], outfile: Optional[PathLikeOrStr] = None, NoDataValue: Optional[Number] = None,
         type: Optional[Union[GDALDataType, str]] = None, format: Optional[str] = None,
         creation_options: Optional[Sequence[str]] = None, allBands: str = '', overwrite: bool = False,
         hideNoData: bool = False, projectionCheck: bool = False,
         color_table: Optional[ColorTableLike] = None,
         extent: Optional[Extent] = None, projwin: Optional[Union[Tuple, GeoRectangle]] = None,
         user_namespace: Optional[Dict]=None,
         debug: bool = False, quiet: bool = False, **input_files):
    """Evaluate raster expression(s) block by block and write the result to a raster.

    Every input raster is read block by block; each block is bound to its keyword
    name (``A``, ``B``, ...) and ``calc`` is evaluated with ``eval`` in a namespace that
    also contains every public name of ``gdal_array`` and ``numpy``.

    NOTE(security): ``calc`` is executed with ``eval``; never pass expressions that
    come from an untrusted source.

    :param calc: one expression, or a sequence of expressions (one output band each).
    :param outfile: output path; may be omitted only when ``format`` is ``'MEM'``.
    :param NoDataValue: output nodata value.  ``None`` uses the default for the output
        type (unless ``hideNoData``); the string ``'none'`` sets no nodata at all.
    :param type: output data type (GDAL type code or name).  Defaults to the input
        type with the highest GDAL type code.
    :param format: GDAL driver name; guessed from ``outfile`` when ``None``.
    :param creation_options: driver creation options for a newly created output.
    :param allBands: name of an input whose every band is processed (single ``calc`` only).
    :param overwrite: recreate ``outfile`` if it exists.  Otherwise an existing file is
        updated in place and must match the inputs' size, projection and geotransform.
    :param hideNoData: ignore input nodata values and do not set a default output nodata.
    :param projectionCheck: raise if the inputs' projections differ.
    :param color_table: color table (or a path to read one from) set on every new band.
    :param extent: an ``Extent`` mode, a value accepted by ``extent_util.parse_extent``,
        or a ``GeoRectangle``; controls how inputs with different geotransforms combine.
    :param projwin: output window, either a ``GeoRectangle`` or a tuple passed to
        ``GeoRectangle.from_lurd``.
    :param user_namespace: extra names made available to the expressions.
    :param debug: print diagnostic messages.
    :param quiet: suppress the progress output.
    :param input_files: ``<name>=<path | Dataset | list/tuple of those>``; a list is exposed
        to the expression as a list of arrays.  ``<name>_band=<n>`` selects the band to
        read (default 1).  Any other value is injected into the eval namespace.
    :return: the output ``gdal.Dataset``.
    :raises IOError: if an input file cannot be opened.
    :raises Exception: on invalid arguments, incompatible inputs or I/O failures.
    """

    if debug:
        print(f"gdal_calc.py starting calculation {calc}")

    # Single calc value compatibility
    if not isinstance(calc, (list, tuple)):
        calc = [calc]
    # shell quoting may leave surrounding double quotes on the expressions
    calc = [c.strip('"') for c in calc]

    creation_options = creation_options or []

    # set up global namespace for eval with all functions of gdal_array, numpy
    global_namespace = {key: getattr(module, key)
                        for module in [gdal_array, numpy] for key in dir(module) if not key.startswith('__')}

    if user_namespace:
        global_namespace.update(user_namespace)

    if not calc:
        raise Exception("No calculation provided.")
    # `format` is only guessed from `outfile` further below, so it may still be None here;
    # without an output file only the in-memory driver makes sense.
    elif not outfile and (format is None or format.upper() != 'MEM'):
        raise Exception("No output file provided.")

    if format is None:
        format = GetOutputDriverFor(outfile)

    if isinstance(extent, GeoRectangle):
        pass
    elif projwin:
        if isinstance(projwin, GeoRectangle):
            extent = projwin
        else:
            extent = GeoRectangle.from_lurd(*projwin)
    elif not extent:
        extent = Extent.IGNORE
    else:
        extent = extent_util.parse_extent(extent)

    compatible_gt_eps = 0.000001
    gt_diff_support = {
        GT.INCOMPATIBLE_OFFSET: extent != Extent.FAIL,
        GT.INCOMPATIBLE_PIXEL_SIZE: False,
        GT.INCOMPATIBLE_ROTATION: False,
        GT.NON_ZERO_ROTATION: False,
    }
    gt_diff_error = {
        GT.INCOMPATIBLE_OFFSET: 'different offset',
        GT.INCOMPATIBLE_PIXEL_SIZE: 'different pixel size',
        GT.INCOMPATIBLE_ROTATION: 'different rotation',
        GT.NON_ZERO_ROTATION: 'non zero rotation',
    }

    ################################################################
    # fetch details of input layers
    ################################################################

    # set up some lists to store data for each band
    myFileNames = []  # input filenames
    myFiles = []  # input DataSets
    myBands = []  # input bands
    myAlphaList = []  # input alpha letter that represents each input file
    myDataType = []  # string representation of the datatype of each input file
    myDataTypeNum = []  # datatype of each input file
    myNDV = []  # nodatavalue for each input file
    DimensionsCheck = None  # dimensions of the output
    Dimensions = []  # Dimensions of input files
    ProjectionCheck = None  # projection of the output
    GeoTransformCheck = None  # GeoTransform of the output
    GeoTransforms = []  # GeoTransform of each input file
    GeoTransformDiffer = False  # True if we have inputs with different GeoTransforms
    myTempFileNames = []  # vrt filename from each input file
    myAlphaFileLists = []  # list of the Alphas which holds a list of inputs

    # loop through input files - checking dimensions
    for alphas, filenames in input_files.items():
        if isinstance(filenames, (list, tuple)):
            # alpha is a list of files
            myAlphaFileLists.append(alphas)
        elif is_path_like(filenames) or isinstance(filenames, gdal.Dataset):
            # alpha is a single filename or a Dataset
            filenames = [filenames]
        else:
            # Any other value is made available to the expressions as a plain variable
            # (passing it via user_namespace would be cleaner, but this is accepted too).
            global_namespace[alphas] = filenames
            continue

        # All files of one entry share the same name.  Pairing them by repeating the key
        # string (``alphas * len(filenames)``) would split multi-character names.
        alpha = alphas
        if alpha.endswith("_band"):
            # "<name>_band" entries only select a band; they are not inputs themselves
            continue
        # check if we have asked for a specific band...
        myBand = input_files.get(f"{alpha}_band", 1)

        for filename in filenames:
            myF_is_ds = not is_path_like(filename)
            if myF_is_ds:
                myFile = filename
                filename = None
            else:
                myFile = open_ds(filename, gdal.GA_ReadOnly)
            if not myFile:
                raise IOError(f"No such file or directory: '{filename}'")

            myFileNames.append(filename)
            myFiles.append(myFile)
            myBands.append(myBand)
            myAlphaList.append(alpha)
            dt = myFile.GetRasterBand(myBand).DataType
            myDataType.append(gdal.GetDataTypeName(dt))
            myDataTypeNum.append(dt)
            myNDV.append(None if hideNoData else myFile.GetRasterBand(myBand).GetNoDataValue())

            # check that the dimensions of each layer are the same
            myFileDimensions = [myFile.RasterXSize, myFile.RasterYSize]
            if DimensionsCheck:
                if DimensionsCheck != myFileDimensions:
                    GeoTransformDiffer = True
                    if extent in [Extent.IGNORE, Extent.FAIL]:
                        raise Exception(
                            f"Error! Dimensions of file {filename} ({myFileDimensions[0]:d}, "
                            f"{myFileDimensions[1]:d}) are different from other files "
                            f"({DimensionsCheck[0]:d}, {DimensionsCheck[1]:d}).  Cannot proceed")
            else:
                DimensionsCheck = myFileDimensions

            # check that the Projection of each layer are the same
            myProjection = myFile.GetProjection()
            if ProjectionCheck:
                if projectionCheck and ProjectionCheck != myProjection:
                    raise Exception(
                        f"Error! Projection of file {filename} {myProjection} "
                        f"are different from other files {ProjectionCheck}.  Cannot proceed")
            else:
                ProjectionCheck = myProjection

            # check that the GeoTransforms of each layer are the same
            myFileGeoTransform = myFile.GetGeoTransform(can_return_null=True)
            if extent == Extent.IGNORE:
                GeoTransformCheck = myFileGeoTransform
            else:
                Dimensions.append(myFileDimensions)
                GeoTransforms.append(myFileGeoTransform)
                if not GeoTransformCheck:
                    GeoTransformCheck = myFileGeoTransform
                else:
                    my_gt_diff = extent_util.gt_diff(GeoTransformCheck, myFileGeoTransform, eps=compatible_gt_eps,
                                                     diff_support=gt_diff_support)
                    if my_gt_diff not in [GT.SAME, GT.ALMOST_SAME]:
                        GeoTransformDiffer = True
                        if my_gt_diff != GT.COMPATIBLE_DIFF:
                            raise Exception(
                                f"Error! GeoTransform of file {filename} {myFileGeoTransform} is incompatible "
                                f"({gt_diff_error[my_gt_diff]}), first file GeoTransform is {GeoTransformCheck}. "
                                f"Cannot proceed")
            if debug:
                print(
                    f"file {alpha}: {filename}, dimensions: "
                    f"{DimensionsCheck[0]}, {DimensionsCheck[1]}, type: {myDataType[-1]}")

    # Everything below (output size, data type, block size) is derived from the inputs.
    if not myFiles:
        raise Exception("Error! No input files were given.  Cannot proceed")

    # process allBands option
    allBandsIndex = None
    allBandsCount = 1
    if allBands:
        if len(calc) > 1:
            raise Exception("Error! --allBands implies a single --calc")
        try:
            allBandsIndex = myAlphaList.index(allBands)
        except ValueError:
            raise Exception(f"Error! allBands option was given but Band {allBands} not found.  Cannot proceed")
        allBandsCount = myFiles[allBandsIndex].RasterCount
        if allBandsCount <= 1:
            allBandsIndex = None
    else:
        allBandsCount = len(calc)

    if extent not in [Extent.IGNORE, Extent.FAIL] and (
        GeoTransformDiffer or isinstance(extent, GeoRectangle)):
        # mixing different GeoTransforms/Extents
        GeoTransformCheck, DimensionsCheck, ExtentCheck = extent_util.calc_geotransform_and_dimensions(
            GeoTransforms, Dimensions, extent)
        if GeoTransformCheck is None:
            raise Exception("Error! The requested extent is empty. Cannot proceed")
        for i in range(len(myFileNames)):
            temp_vrt_filename, temp_vrt_ds = extent_util.make_temp_vrt(myFiles[i], ExtentCheck)
            myTempFileNames.append(temp_vrt_filename)
            myFiles[i] = None  # close original ds
            myFiles[i] = temp_vrt_ds  # replace original ds with vrt_ds

            # update the new precise dimensions and gt from the new ds
            GeoTransformCheck = temp_vrt_ds.GetGeoTransform()
            DimensionsCheck = [temp_vrt_ds.RasterXSize, temp_vrt_ds.RasterYSize]
        temp_vrt_ds = None

    ################################################################
    # set up output file
    ################################################################

    # open output file exists
    if outfile and os.path.isfile(outfile) and not overwrite:
        if allBandsIndex is not None:
            raise Exception("Error! allBands option was given but Output file exists, must use --overwrite option!")
        if len(calc) > 1:
            raise Exception(
                "Error! multiple calc options were given but Output file exists, must use --overwrite option!")
        if debug:
            print(f"Output file {outfile} exists - filling in results into file")

        myOut = open_ds(outfile, gdal.GA_Update)
        if myOut is None:
            error = 'but cannot be opened for update'
        elif [myOut.RasterXSize, myOut.RasterYSize] != DimensionsCheck:
            error = 'but is the wrong size'
        elif ProjectionCheck and ProjectionCheck != myOut.GetProjection():
            error = 'but is the wrong projection'
        elif GeoTransformCheck and GeoTransformCheck != myOut.GetGeoTransform(can_return_null=True):
            error = 'but is the wrong geotransform'
        else:
            error = None
        if error:
            raise Exception(
                f"Error! Output exists, {error}.  Use the --overwrite option "
                f"to automatically overwrite the existing file")

        myOutB = myOut.GetRasterBand(1)
        myOutNDV = myOutB.GetNoDataValue()
        myOutType = myOutB.DataType

    else:
        if outfile:
            # remove existing file and regenerate
            if os.path.isfile(outfile):
                os.remove(outfile)
            # create a new file
            if debug:
                print(f"Generating output file {outfile}")
        else:
            outfile = ''

        # find data type to use
        if not type:
            # use the largest type of the input files
            myOutType = max(myDataTypeNum)
        else:
            myOutType = type
            if isinstance(myOutType, str):
                myOutType = gdal.GetDataTypeByName(myOutType)

        # create file
        myOutDrv = gdal.GetDriverByName(format)
        if myOutDrv is None:
            raise Exception(f"Error! Unknown output format {format}")
        myOut = myOutDrv.Create(
            os.fspath(outfile), DimensionsCheck[0], DimensionsCheck[1], allBandsCount,
            myOutType, creation_options)
        if myOut is None:
            raise Exception(f"Error! Failed to create output file {outfile}")

        # set output geo info based on first input layer
        if not GeoTransformCheck:
            GeoTransformCheck = myFiles[0].GetGeoTransform(can_return_null=True)
        if GeoTransformCheck:
            myOut.SetGeoTransform(GeoTransformCheck)

        if not ProjectionCheck:
            ProjectionCheck = myFiles[0].GetProjection()
        if ProjectionCheck:
            myOut.SetProjection(ProjectionCheck)

        if NoDataValue is None:
            myOutNDV = None if hideNoData else DefaultNDVLookup[
                myOutType]  # use the default noDataValue for this datatype
        elif isinstance(NoDataValue, str) and NoDataValue.lower() == 'none':
            myOutNDV = None  # not to set any noDataValue
        else:
            myOutNDV = NoDataValue  # use the given noDataValue

        for i in range(1, allBandsCount + 1):
            myOutB = myOut.GetRasterBand(i)
            if myOutNDV is not None:
                myOutB.SetNoDataValue(myOutNDV)
            if color_table:
                # set color table and color interpretation
                if is_path_like(color_table):
                    color_table = get_color_table(color_table)
                myOutB.SetRasterColorTable(color_table)
                myOutB.SetRasterColorInterpretation(gdal.GCI_PaletteIndex)

            myOutB = None  # write to band

    myOutTypeName = gdal.GetDataTypeName(myOutType)
    if debug:
        print(f"output file: {outfile}, dimensions: {myOut.RasterXSize}, {myOut.RasterYSize}, type: {myOutTypeName}")

    ################################################################
    # find block size to chop grids into bite-sized chunks
    ################################################################

    # use the block size of the first layer to read efficiently
    myBlockSize = myFiles[0].GetRasterBand(myBands[0]).GetBlockSize()
    # find total x and y blocks to be read (integer ceiling division)
    nXBlocks = (DimensionsCheck[0] + myBlockSize[0] - 1) // myBlockSize[0]
    nYBlocks = (DimensionsCheck[1] + myBlockSize[1] - 1) // myBlockSize[1]

    if debug:
        print(f"using blocksize {myBlockSize[0]} x {myBlockSize[1]}")

    # variables for displaying progress; ProgressMk is the last reported decile
    ProgressCt = -1
    ProgressMk = -1
    ProgressEnd = nXBlocks * nYBlocks * allBandsCount

    ################################################################
    # start looping through each band in allBandsCount
    ################################################################

    for bandNo in range(1, allBandsCount + 1):

        ################################################################
        # start looping through blocks of data
        ################################################################

        # store these numbers in variables that may change later
        nXValid = myBlockSize[0]
        nYValid = myBlockSize[1]

        # loop through X-lines
        for X in range(0, nXBlocks):

            # in case the blocks don't fit perfectly
            # change the block size of the final piece
            if X == nXBlocks - 1:
                nXValid = DimensionsCheck[0] - X * myBlockSize[0]

            # find X offset
            myX = X * myBlockSize[0]

            # reset block height for start of Y loop
            nYValid = myBlockSize[1]

            # loop through Y lines
            for Y in range(0, nYBlocks):
                ProgressCt += 1
                if not quiet:
                    # Integer division: report once per 10% step.  True division would
                    # produce a new value (and a new print) for every single block.
                    progress_decile = 10 * ProgressCt // ProgressEnd
                    if progress_decile != ProgressMk:
                        ProgressMk = progress_decile
                        print(f"{10 * ProgressMk}..", end=" ")

                # change the block size of the final piece
                if Y == nYBlocks - 1:
                    nYValid = DimensionsCheck[1] - Y * myBlockSize[1]

                # find Y offset
                myY = Y * myBlockSize[1]

                # create empty buffer to mark where nodata occurs
                myNDVs = None

                # make local namespace for calculation
                local_namespace = {}

                val_lists = defaultdict(list)

                # fetch data for each input layer
                for i, Alpha in enumerate(myAlphaList):

                    # populate lettered arrays with values
                    if allBandsIndex is not None and allBandsIndex == i:
                        myBandNo = bandNo
                    else:
                        myBandNo = myBands[i]
                    myval = gdal_array.BandReadAsArray(myFiles[i].GetRasterBand(myBandNo),
                                                       xoff=myX, yoff=myY,
                                                       win_xsize=nXValid, win_ysize=nYValid)
                    if myval is None:
                        raise Exception(f'Input block reading failed from filename {myFileNames[i]}')

                    # fill in nodata values
                    if myNDV[i] is not None:
                        # myNDVs is a 0/1 buffer: a cell is 1 if any input band with a nodata
                        # value has nodata in that cell.
                        if myNDVs is None:
                            # first band that has NDV set: start from an all-zero buffer
                            myNDVs = numpy.zeros((nYValid, nXValid))
                        # NaN never compares equal to itself, so a NaN nodata needs isnan.
                        if numpy.isnan(myNDV[i]):
                            is_nodata = numpy.isnan(myval)
                        else:
                            is_nodata = myval == myNDV[i]
                        myNDVs = 1 * numpy.logical_or(myNDVs == 1, is_nodata)

                    # add an array of values for this block to the eval namespace
                    if Alpha in myAlphaFileLists:
                        val_lists[Alpha].append(myval)
                    else:
                        local_namespace[Alpha] = myval
                    myval = None

                for lst in myAlphaFileLists:
                    local_namespace[lst] = val_lists[lst]

                # try the calculation on the array blocks
                this_calc = calc[bandNo - 1 if len(calc) > 1 else 0]
                try:
                    myResult = eval(this_calc, global_namespace, local_namespace)
                except Exception:
                    print(f"evaluation of calculation {this_calc} failed")
                    raise

                # Propagate nodata values (set nodata cells to zero
                # then add nodata value to these cells).
                if myNDVs is not None and myOutNDV is not None:
                    myResult = ((1 * (myNDVs == 0)) * myResult) + (myOutNDV * myNDVs)
                elif not isinstance(myResult, numpy.ndarray):
                    # a scalar result is broadcast over the whole block
                    myResult = numpy.ones((nYValid, nXValid)) * myResult

                # write data block to the output file
                myOutB = myOut.GetRasterBand(bandNo)
                if gdal_array.BandWriteArray(myOutB, myResult, xoff=myX, yoff=myY) != 0:
                    raise Exception('Block writing failed')
                myOutB = None  # write to band

    # remove temp files (every input was replaced by its temp VRT, so indices match)
    for idx, tempFile in enumerate(myTempFileNames):
        myFiles[idx] = None
        os.remove(tempFile)

    gdal.ErrorReset()
    myOut.FlushCache()
    if gdal.GetLastErrorMsg() != '':
        raise Exception('Dataset writing failed')

    if not quiet:
        print("100 - Done")

    return myOut
예제 #16
0
파일: ogrmerge.py 프로젝트: edzer/gdal
def ogrmerge(
    src_datasets: Optional[Sequence[str]] = None,
    dst_filename: Optional[PathLikeOrStr] = None,
    driver_name: Optional[str] = None,
    overwrite_ds: bool = False,
    overwrite_layer: bool = False,
    update: bool = False,
    append: bool = False,
    single_layer: bool = False,
    layer_name_template: Optional[str]  = None,
    skip_failures: bool = False,
    src_geom_types: Optional[Sequence[int]] = None,
    field_strategy: Optional[str] = None,
    src_layer_field_name: Optional[str] = None,
    src_layer_field_content: Optional[str] = None,
    a_srs: Optional[str] = None,
    s_srs: Optional[str] = None,
    t_srs: Optional[str] = None,
    dsco: Optional[Sequence[str]] = None,
    lco: Optional[Sequence[str]] = None,
    progress_callback: Optional = None, progress_arg: Optional = None):
    """Merge the layers of several vector datasets into one output dataset.

    An OGR VRT document listing every selected source layer is written first:
    to ``/vsimem/_ogrmerge_.vrt``, or to ``dst_filename`` itself when the output
    driver is VRT.  For any other driver the VRT is then copied into the output
    with ``gdal.VectorTranslate`` and the temporary VRT is removed.

    :param src_datasets: source dataset names.
    :param dst_filename: output dataset name.
    :param driver_name: output driver; guessed from ``dst_filename`` when ``None``.
        Must not be given together with ``update``.
    :param overwrite_ds: delete an already existing output dataset first.
    :param overwrite_layer: overwrite existing output layers.
    :param update: add to an existing output dataset (which must exist).
    :param append: append to existing output layers.
    :param single_layer: merge all source layers into one union layer.
    :param layer_name_template: output layer name; may contain ``{AUTO_NAME}``,
        ``{DS_BASENAME}``, ``{DS_NAME}``, ``{DS_INDEX}``, ``{LAYER_NAME}`` and
        ``{LAYER_INDEX}``.  Defaults to ``'merged'`` (single layer) or ``'{AUTO_NAME}'``.
    :param skip_failures: skip sources that cannot be opened and tolerate translation failures.
    :param src_geom_types: if non-empty, only layers whose flattened geometry type is listed.
    :param field_strategy: ``FieldStrategy`` of the union layer (single-layer mode).
    :param src_layer_field_name: ``SourceLayerFieldName`` of the union layer; defaults to
        ``'source_ds_lyr'`` when ``src_layer_field_content`` is given.
    :param src_layer_field_content: per-source layer name template in single-layer mode
        (same placeholders as ``layer_name_template``); defaults to ``'{AUTO_NAME}'``.
    :param a_srs: SRS assigned to the source layers (``LayerSRS``).
    :param s_srs: source SRS used for reprojection (only with ``t_srs``).
    :param t_srs: target SRS; each layer is wrapped in an ``OGRVRTWarpedLayer``.
    :param dsco: dataset creation options (not allowed with ``update``).
    :param lco: layer creation options passed to ``gdal.VectorTranslate``.
    :param progress_callback: progress callback passed to ``gdal.VectorTranslate``.
    :param progress_arg: user data passed to ``progress_callback``.
    :return: 0 on success, 1 on failure (an error message is printed).
    """

    src_datasets = src_datasets or []
    src_geom_types = src_geom_types or []
    dsco = dsco or []
    lco = lco or []
    if dst_filename is not None:
        # Work on a plain str: the string operations below (.lower(), `in`) would
        # fail on a path-like object.
        dst_filename = os.fspath(dst_filename)

    if update:
        if driver_name is not None:
            print('ERROR: -f incompatible with -update')
            return 1
        if dsco:
            print('ERROR: -dsco incompatible with -update')
            return 1
        driver_name = ''
    else:
        if driver_name is None:
            driver_name = GetOutputDriverFor(dst_filename, is_raster=False)

    if src_layer_field_content is None:
        src_layer_field_content = '{AUTO_NAME}'
    elif src_layer_field_name is None:
        src_layer_field_name = 'source_ds_lyr'

    if not single_layer and driver_name == 'ESRI Shapefile' and \
       dst_filename.lower().endswith('.shp'):
        print('ERROR: Non-single layer mode incompatible with non-directory '
              'shapefile output')
        return 1

    if not src_datasets:
        print('ERROR: No source datasets')
        return 1

    if layer_name_template is None:
        if single_layer:
            layer_name_template = 'merged'
        else:
            layer_name_template = '{AUTO_NAME}'

    vrt_filename = None
    if not EQUAL(driver_name, 'VRT'):
        dst_ds = gdal.OpenEx(dst_filename, gdal.OF_VECTOR | gdal.OF_UPDATE)
        if dst_ds is not None:
            if not update and not overwrite_ds:
                print('ERROR: Destination dataset already exists, ' +
                      'but neither -update nor -overwrite_ds is specified')
                return 1
            if overwrite_ds:
                drv = dst_ds.GetDriver()
                dst_ds = None
                if drv.GetDescription() == 'OGR_VRT':
                    # We don't want to destroy the sources of the VRT
                    gdal.Unlink(dst_filename)
                else:
                    drv.Delete(dst_filename)
        elif update:
            print('ERROR: Destination dataset does not exist')
            return 1
        if dst_ds is None:
            drv = gdal.GetDriverByName(driver_name)
            if drv is None:
                print('ERROR: Invalid driver: %s' % driver_name)
                return 1
            dst_ds = drv.Create(
                dst_filename, 0, 0, 0, gdal.GDT_Unknown, dsco)
            if dst_ds is None:
                return 1

        vrt_filename = '/vsimem/_ogrmerge_.vrt'
    else:
        if gdal.VSIStatL(dst_filename) and not overwrite_ds:
            print('ERROR: Destination dataset already exists, ' +
                  'but -overwrite_ds is not specified')
            return 1
        vrt_filename = dst_filename

    f = gdal.VSIFOpenL(vrt_filename, 'wb')
    if f is None:
        print('ERROR: Cannot create %s' % vrt_filename)
        return 1

    writer = XMLWriter(f)
    writer.open_element('OGRVRTDataSource')

    if single_layer:

        ogr_vrt_union_layer_written = False

        for src_ds_idx, src_dsname in enumerate(src_datasets):
            src_ds = ogr.Open(src_dsname)
            if src_ds is None:
                print('ERROR: Cannot open %s' % src_dsname)
                if skip_failures:
                    continue
                gdal.VSIFCloseL(f)
                gdal.Unlink(vrt_filename)
                return 1
            for src_lyr_idx, src_lyr in enumerate(src_ds):
                if src_geom_types:
                    gt = ogr.GT_Flatten(src_lyr.GetGeomType())
                    if gt not in src_geom_types:
                        continue

                # the union layer element is opened lazily, before its first member
                if not ogr_vrt_union_layer_written:
                    ogr_vrt_union_layer_written = True
                    writer.open_element('OGRVRTUnionLayer',
                                        attrs={'name': layer_name_template})

                    if src_layer_field_name is not None:
                        writer.write_element_value('SourceLayerFieldName',
                                                   src_layer_field_name)

                    if field_strategy is not None:
                        writer.write_element_value('FieldStrategy',
                                                   field_strategy)

                layer_name = src_layer_field_content

                src_lyr_name = src_lyr.GetName()
                try:
                    src_lyr_name = src_lyr_name.decode('utf-8')
                except AttributeError:
                    pass

                # basename of an existing source file, without its last extension
                basename = None
                if os.path.exists(src_dsname):
                    basename = os.path.basename(src_dsname)
                    if '.' in basename:
                        basename = '.'.join(basename.split(".")[0:-1])

                if basename == src_lyr_name:
                    layer_name = layer_name.replace('{AUTO_NAME}', basename)
                elif basename is None:
                    layer_name = layer_name.replace(
                        '{AUTO_NAME}',
                        'Dataset%d_%s' % (src_ds_idx, src_lyr_name))
                else:
                    layer_name = layer_name.replace(
                        '{AUTO_NAME}', basename + '_' + src_lyr_name)

                if basename is not None:
                    layer_name = layer_name.replace('{DS_BASENAME}', basename)
                else:
                    layer_name = layer_name.replace('{DS_BASENAME}',
                                                    src_dsname)
                layer_name = layer_name.replace('{DS_NAME}', '%s' %
                                                src_dsname)
                layer_name = layer_name.replace('{DS_INDEX}', '%d' %
                                                src_ds_idx)
                layer_name = layer_name.replace('{LAYER_NAME}',
                                                src_lyr_name)
                layer_name = layer_name.replace('{LAYER_INDEX}', '%d' %
                                                src_lyr_idx)

                if t_srs is not None:
                    writer.open_element('OGRVRTWarpedLayer')

                writer.open_element('OGRVRTLayer',
                                    attrs={'name': layer_name})
                attrs = {}
                if EQUAL(driver_name, 'VRT') and \
                   os.path.exists(src_dsname) and \
                   not os.path.isabs(src_dsname) and \
                   '/' not in vrt_filename and \
                   '\\' not in vrt_filename:
                    attrs['relativeToVRT'] = '1'
                # single_layer is always true in this branch
                attrs['shared'] = '1'
                writer.write_element_value('SrcDataSource', src_dsname,
                                           attrs=attrs)
                # use the decoded name, as the multi-layer branch does
                writer.write_element_value('SrcLayer', src_lyr_name)

                if a_srs is not None:
                    writer.write_element_value('LayerSRS', a_srs)

                writer.close_element('OGRVRTLayer')

                if t_srs is not None:
                    if s_srs is not None:
                        writer.write_element_value('SrcSRS', s_srs)

                    writer.write_element_value('TargetSRS', t_srs)

                    writer.close_element('OGRVRTWarpedLayer')

        if ogr_vrt_union_layer_written:
            writer.close_element('OGRVRTUnionLayer')

    else:

        for src_ds_idx, src_dsname in enumerate(src_datasets):
            src_ds = ogr.Open(src_dsname)
            if src_ds is None:
                print('ERROR: Cannot open %s' % src_dsname)
                if skip_failures:
                    continue
                gdal.VSIFCloseL(f)
                gdal.Unlink(vrt_filename)
                return 1
            for src_lyr_idx, src_lyr in enumerate(src_ds):
                if src_geom_types:
                    gt = ogr.GT_Flatten(src_lyr.GetGeomType())
                    if gt not in src_geom_types:
                        continue

                src_lyr_name = src_lyr.GetName()
                try:
                    src_lyr_name = src_lyr_name.decode('utf-8')
                except AttributeError:
                    pass

                layer_name = layer_name_template
                # basename of an existing source file, without its last extension
                basename = None
                if os.path.exists(src_dsname):
                    basename = os.path.basename(src_dsname)
                    if '.' in basename:
                        basename = '.'.join(basename.split(".")[0:-1])

                if basename == src_lyr_name:
                    layer_name = layer_name.replace('{AUTO_NAME}', basename)
                elif basename is None:
                    layer_name = layer_name.replace(
                        '{AUTO_NAME}',
                        'Dataset%d_%s' % (src_ds_idx, src_lyr_name))
                else:
                    layer_name = layer_name.replace(
                        '{AUTO_NAME}', basename + '_' + src_lyr_name)

                if basename is not None:
                    layer_name = layer_name.replace('{DS_BASENAME}', basename)
                elif '{DS_BASENAME}' in layer_name:
                    if skip_failures:
                        if '{DS_INDEX}' not in layer_name:
                            layer_name = layer_name.replace(
                                '{DS_BASENAME}', 'Dataset%d' % src_ds_idx)
                    else:
                        print('ERROR: Layer name template %s '
                              'includes {DS_BASENAME} '
                              'but %s is not a file' %
                              (layer_name_template, src_dsname))

                        gdal.VSIFCloseL(f)
                        gdal.Unlink(vrt_filename)
                        return 1
                layer_name = layer_name.replace('{DS_NAME}', '%s' %
                                                src_dsname)
                layer_name = layer_name.replace('{DS_INDEX}', '%d' %
                                                src_ds_idx)
                layer_name = layer_name.replace('{LAYER_NAME}',
                                                src_lyr_name)
                layer_name = layer_name.replace('{LAYER_INDEX}', '%d' %
                                                src_lyr_idx)

                if t_srs is not None:
                    writer.open_element('OGRVRTWarpedLayer')

                writer.open_element('OGRVRTLayer',
                                    attrs={'name': layer_name})
                attrs = {}
                if EQUAL(driver_name, 'VRT') and \
                   os.path.exists(src_dsname) and \
                   not os.path.isabs(src_dsname) and \
                   '/' not in vrt_filename and \
                   '\\' not in vrt_filename:
                    attrs['relativeToVRT'] = '1'
                # no 'shared' attribute here: single_layer is always false in this branch
                writer.write_element_value('SrcDataSource', src_dsname,
                                           attrs=attrs)
                writer.write_element_value('SrcLayer', src_lyr_name)

                if a_srs is not None:
                    writer.write_element_value('LayerSRS', a_srs)

                writer.close_element('OGRVRTLayer')

                if t_srs is not None:
                    if s_srs is not None:
                        writer.write_element_value('SrcSRS', s_srs)

                    writer.write_element_value('TargetSRS', t_srs)

                    writer.close_element('OGRVRTWarpedLayer')

    writer.close_element('OGRVRTDataSource')

    gdal.VSIFCloseL(f)

    ret = 0
    if not EQUAL(driver_name, 'VRT'):
        accessMode = None
        if append:
            accessMode = 'append'
        elif overwrite_layer:
            accessMode = 'overwrite'
        ret = gdal.VectorTranslate(dst_ds, vrt_filename,
                                   accessMode=accessMode,
                                   layerCreationOptions=lco,
                                   skipFailures=skip_failures,
                                   callback=progress_callback,
                                   callback_data=progress_arg)
        # with an existing destination dataset VectorTranslate reports success as 1;
        # map it to this function's 0 = success / 1 = failure convention
        if ret == 1:
            ret = 0
        else:
            ret = 1
        gdal.Unlink(vrt_filename)

    return ret
예제 #17
0
def gdal_merge(argv=None):
    """Mosaic a set of raster files into a single output raster.

    Command-line style entry point: parses ``argv`` (options followed by
    input file names). It then derives the output extent, pixel size and
    band type from the inputs unless given explicitly. It creates the output
    file, or opens it for update if it already exists. It optionally assigns
    nodata and pre-fills the bands, and finally copies every input into the
    output.

    :param argv: full argument vector; index 0 is treated as the program
        name and parsing starts at index 1. Defaults to ``sys.argv``.
    :returns: 1 on an error detected here, the value of ``Usage()`` on a
        command-line error, 0 when ``gdal.GeneralCmdLineProcessor`` returns
        None, and None when the merge runs to completion.
    """
    import sys

    verbose = 0
    quiet = 0
    names = []
    driver_name = None
    out_file = 'out.tif'

    ulx = None
    psize_x = None
    separate = 0
    copy_pct = 0
    nodata = None
    a_nodata = None
    create_options = []
    pre_init = []
    band_type = None
    createonly = 0
    bTargetAlignedPixels = False
    start_time = time.time()

    if argv is None:
        argv = sys.argv
    argv = gdal.GeneralCmdLineProcessor(argv)
    if argv is None:
        return 0

    # Number of values that follow each option taking arguments. Used to
    # report a usage error instead of raising IndexError on a truncated
    # command line.
    option_arity = {
        '-o': 1, '-ot': 1, '-init': 1, '-n': 1, '-a_nodata': 1,
        '-f': 1, '-of': 1, '-co': 1, '-ps': 2, '-ul_lr': 4,
    }

    # Parse command line arguments.
    i = 1
    while i < len(argv):
        arg = argv[i]

        if i + option_arity.get(arg, 0) >= len(argv):
            print('Missing value(s) for option: %s' % arg)
            return Usage()

        if arg == '-o':
            i = i + 1
            out_file = argv[i]

        elif arg == '-v':
            verbose = 1

        elif arg == '-q' or arg == '-quiet':
            quiet = 1

        elif arg == '-createonly':
            createonly = 1

        elif arg == '-separate':
            separate = 1

        elif arg == '-seperate':
            # Misspelled alias accepted alongside -separate.
            separate = 1

        elif arg == '-pct':
            copy_pct = 1

        elif arg == '-ot':
            i = i + 1
            band_type = gdal.GetDataTypeByName(argv[i])
            if band_type == gdal.GDT_Unknown:
                print('Unknown GDAL data type: %s' % argv[i])
                return 1

        elif arg == '-init':
            # A single argument holding one or more whitespace-separated
            # values.
            i = i + 1
            str_pre_init = argv[i].split()
            for x in str_pre_init:
                pre_init.append(float(x))

        elif arg == '-n':
            i = i + 1
            nodata = float(argv[i])

        elif arg == '-a_nodata':
            i = i + 1
            a_nodata = float(argv[i])

        elif arg == '-f' or arg == '-of':
            i = i + 1
            driver_name = argv[i]

        elif arg == '-co':
            i = i + 1
            create_options.append(argv[i])

        elif arg == '-ps':
            psize_x = float(argv[i + 1])
            # The Y pixel size is always stored negative (north-up
            # geotransform), whatever sign the user passed.
            psize_y = -1 * abs(float(argv[i + 2]))
            i = i + 2

        elif arg == '-tap':
            bTargetAlignedPixels = True

        elif arg == '-ul_lr':
            ulx = float(argv[i + 1])
            uly = float(argv[i + 2])
            lrx = float(argv[i + 3])
            lry = float(argv[i + 4])
            i = i + 4

        elif arg[:1] == '-':
            print('Unrecognized command option: %s' % arg)
            return Usage()

        else:
            names.append(arg)

        i = i + 1

    if not names:
        print('No input files selected.')
        return Usage()

    if driver_name is None:
        driver_name = GetOutputDriverFor(out_file)

    driver = gdal.GetDriverByName(driver_name)
    if driver is None:
        print('Format driver %s not found, pick a supported driver.' %
              driver_name)
        return 1

    DriverMD = driver.GetMetadata()
    if 'DCAP_CREATE' not in DriverMD:
        print(
            'Format driver %s does not support creation and piecewise writing.\nPlease select a format that does, such as GTiff (the default) or HFA (Erdas Imagine).'
            % driver_name)
        return 1

    # Collect information on all the source files.
    file_infos = names_to_fileinfos(names)
    if not file_infos:
        # Everything below reads file_infos[0]; without at least one entry
        # there is no extent, resolution or band type to derive.
        print('No valid input files found.')
        return 1

    # Without -ul_lr, the output extent is the union of all input extents.
    if ulx is None:
        ulx = file_infos[0].ulx
        uly = file_infos[0].uly
        lrx = file_infos[0].lrx
        lry = file_infos[0].lry

        for fi in file_infos:
            ulx = min(ulx, fi.ulx)
            uly = max(uly, fi.uly)
            lrx = max(lrx, fi.lrx)
            lry = min(lry, fi.lry)

    # Without -ps, use the first input's pixel size.
    if psize_x is None:
        psize_x = file_infos[0].geotransform[1]
        psize_y = file_infos[0].geotransform[5]

    if band_type is None:
        band_type = file_infos[0].band_type

    # Try opening as an existing file; errors are silenced because a
    # missing output is the normal case.
    gdal.PushErrorHandler('CPLQuietErrorHandler')
    t_fh = gdal.Open(out_file, gdal.GA_Update)
    gdal.PopErrorHandler()

    # Create output file if it does not already exist.
    if t_fh is None:

        if bTargetAlignedPixels:
            # Snap the extent outward to a multiple of the pixel size.
            ulx = math.floor(ulx / psize_x) * psize_x
            lrx = math.ceil(lrx / psize_x) * psize_x
            lry = math.floor(lry / -psize_y) * -psize_y
            uly = math.ceil(uly / -psize_y) * -psize_y

        geotransform = [ulx, psize_x, 0, uly, 0, psize_y]

        xsize = int((lrx - ulx) / geotransform[1] + 0.5)
        ysize = int((lry - uly) / geotransform[5] + 0.5)

        if separate != 0:
            # One output band per input band, across all inputs.
            bands = 0

            for fi in file_infos:
                bands = bands + fi.bands
        else:
            bands = file_infos[0].bands

        t_fh = driver.Create(out_file, xsize, ysize, bands, band_type,
                             create_options)
        if t_fh is None:
            print('Creation failed, terminating gdal_merge.')
            return 1

        t_fh.SetGeoTransform(geotransform)
        t_fh.SetProjection(file_infos[0].projection)

        if copy_pct:
            t_fh.GetRasterBand(1).SetRasterColorTable(file_infos[0].ct)
    else:
        if separate != 0:
            bands = 0
            for fi in file_infos:
                bands = bands + fi.bands
            if t_fh.RasterCount < bands:
                print(
                    'Existing output file has less bands than the input files. You should delete it before. Terminating gdal_merge.'
                )
                return 1
        else:
            bands = min(file_infos[0].bands, t_fh.RasterCount)

    # Do we need to set nodata value ?
    if a_nodata is not None:
        for i in range(t_fh.RasterCount):
            t_fh.GetRasterBand(i + 1).SetNoDataValue(a_nodata)

    # Do we need to pre-initialize the whole mosaic file to some value?
    # Either one value per band, or a single value applied to every band;
    # any other count is ignored.
    if pre_init:
        if t_fh.RasterCount <= len(pre_init):
            for i in range(t_fh.RasterCount):
                t_fh.GetRasterBand(i + 1).Fill(pre_init[i])
        elif len(pre_init) == 1:
            for i in range(t_fh.RasterCount):
                t_fh.GetRasterBand(i + 1).Fill(pre_init[0])

    # Copy data from source files into output file.
    t_band = 1

    if quiet == 0 and verbose == 0:
        progress(0.0)
    fi_processed = 0

    for fi in file_infos:
        if createonly != 0:
            continue

        if verbose != 0:
            print("")
            print(
                "Processing file %5d of %5d, %6.3f%% completed in %d minutes."
                %
                (fi_processed + 1, len(file_infos), fi_processed * 100.0 /
                 len(file_infos), int(round(
                     (time.time() - start_time) / 60.0))))
            fi.report()

        if separate == 0:
            # Band N of every input goes into band N of the output.
            for band in range(1, bands + 1):
                fi.copy_into(t_fh, band, band, nodata, verbose)
        else:
            # Each input band gets its own consecutive output band.
            for band in range(1, fi.bands + 1):
                fi.copy_into(t_fh, band, t_band, nodata, verbose)
                t_band = t_band + 1

        fi_processed = fi_processed + 1
        if quiet == 0 and verbose == 0:
            progress(fi_processed / float(len(file_infos)))

    # Force file to be closed.
    t_fh = None