Code example #1
def test_misc_9():

    old_val = gdal.GetCacheMax()
    gdal.SetCacheMax(3000000000)
    ret_val = gdal.GetCacheMax()
    gdal.SetCacheMax(old_val)

    assert ret_val == 3000000000, 'did not get expected value'
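Examples #1 and #2 both use the same save/set/restore pattern around the block cache limit. A minimal sketch (not from any of the projects listed here) that wraps this pattern in a context manager so the previous limit is restored even if the body raises; it assumes only the documented gdal.GetCacheMax()/gdal.SetCacheMax() calls:

import contextlib
from osgeo import gdal

@contextlib.contextmanager
def cache_max(num_bytes):
    # Save the current limit, apply the new one, restore on exit.
    old_val = gdal.GetCacheMax()
    gdal.SetCacheMax(num_bytes)
    try:
        yield
    finally:
        gdal.SetCacheMax(old_val)

# Usage:
# with cache_max(3000000000):
#     ...  # code that needs the larger block cache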
Code example #2
File: misc.py Project: shinerbass/gdal
def misc_9():

    old_val = gdal.GetCacheMax()
    gdal.SetCacheMax(3000000000)
    ret_val = gdal.GetCacheMax()
    gdal.SetCacheMax(old_val)

    if ret_val != 3000000000:
        gdaltest.post_reason('did not get expected value')
        print(ret_val)
        return 'fail'

    return 'success'
Code example #3
def mask_23():

    drv = gdal.GetDriverByName('GTiff')
    md = drv.GetMetadata()
    if md['DMD_CREATIONOPTIONLIST'].find('JPEG') == -1:
        return 'skip'

    src_ds = drv.Create('tmp/mask_23_src.tif', 3000, 2000, 3, options=['TILED=YES', 'SPARSE_OK=YES'])
    src_ds.CreateMaskBand(gdal.GMF_PER_DATASET)

    gdal.SetConfigOption('GDAL_TIFF_INTERNAL_MASK', 'YES')
    old_val = gdal.GetCacheMax()
    gdal.SetCacheMax(15000000)
    gdal.ErrorReset()
    ds = drv.CreateCopy('tmp/mask_23_dst.tif', src_ds, options=['TILED=YES', 'COMPRESS=JPEG'])
    gdal.SetConfigOption('GDAL_TIFF_INTERNAL_MASK', 'NO')
    gdal.SetCacheMax(old_val)

    del ds
    error_msg = gdal.GetLastErrorMsg()
    src_ds = None

    drv.Delete('tmp/mask_23_src.tif')
    drv.Delete('tmp/mask_23_dst.tif')

    # 'ERROR 1: TIFFRewriteDirectory:Error fetching directory count' was triggered before
    if error_msg != '':
        return 'fail'

    return 'success'
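Example #3 also has to pair every gdal.SetConfigOption('GDAL_TIFF_INTERNAL_MASK', 'YES') with a matching reset. Newer bindings expose a context manager for this; a sketch assuming GDAL >= 3.4 (where gdal.config_option was added; worth verifying against your version), with drv and src_ds as in example #3:

from osgeo import gdal

with gdal.config_option('GDAL_TIFF_INTERNAL_MASK', 'YES'):
    ds = drv.CreateCopy('tmp/mask_23_dst.tif', src_ds,
                        options=['TILED=YES', 'COMPRESS=JPEG'])
# the config option is restored automatically here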
Code example #4
def basic_test_7_internal():
    try:
        gdal.Open('non_existing_ds', gdal.GA_ReadOnly)
        gdaltest.post_reason('opening should have thrown an exception')
        return 'fail'
    except:
        # Special case: we should still be able to get the error message
        # until we call a new GDAL function
        if not matches_non_existing_error_msg(gdal.GetLastErrorMsg()):
            gdaltest.post_reason('did not get expected error message, got %s' %
                                 gdal.GetLastErrorMsg())
            return 'fail'

        if gdal.GetLastErrorType() == 0:
            gdaltest.post_reason('did not get expected error type')
            return 'fail'

        # Should issue an implicit CPLErrorReset()
        gdal.GetCacheMax()

        if gdal.GetLastErrorType() != 0:
            gdaltest.post_reason('got unexpected error type')
            return 'fail'

        return 'success'
Code example #5
def basic_test_7_internal():
    try:
        gdal.Open('non_existing_ds', gdal.GA_ReadOnly)
        gdaltest.post_reason('opening should have thrown an exception')
        return 'fail'
    except:
        # Special case: we should still be able to get the error message
        # until we call a new GDAL function
        expected_msg = ('`non_existing_ds\' does not exist in the file system,\n'
                        'and is not recognised as a supported dataset name.\n')
        if gdal.GetLastErrorMsg() != expected_msg:
            gdaltest.post_reason('did not get expected error message')
            return 'fail'

        if gdal.GetLastErrorType() == 0:
            gdaltest.post_reason('did not get expected error type')
            return 'fail'

        # Should issue an implicit CPLErrorReset()
        gdal.GetCacheMax()

        if gdal.GetLastErrorType() != 0:
            gdaltest.post_reason('got unexpected error type')
            return 'fail'

        return 'success'
Code example #6
File: jpeg.py Project: idixon/gdal-1
def test_jpeg_18():
    height = 1024
    width = 1024
    src_ds = gdal.GetDriverByName('GTiff').Create('/vsimem/jpeg_18.tif', width,
                                                  height, 1)
    for i in range(height):
        data = struct.pack('B' * 1, int(i / (height / 256)))
        src_ds.WriteRaster(0, i, width, 1, data, 1, 1)

    ds = gdal.GetDriverByName('JPEG').CreateCopy('/vsimem/jpeg_18.jpg',
                                                 src_ds,
                                                 options=['QUALITY=99'])
    src_ds = None
    gdal.Unlink('/vsimem/jpeg_18.tif')

    oldSize = gdal.GetCacheMax()
    gdal.SetCacheMax(0)

    line0 = ds.GetRasterBand(1).ReadRaster(0, 0, width, 1)
    data = struct.unpack('B' * width, line0)
    assert abs(data[0] - 0) <= 10
    line1023 = ds.GetRasterBand(1).ReadRaster(0, height - 1, width, 1)
    data = struct.unpack('B' * width, line1023)
    assert abs(data[0] - 255) <= 10
    line0_ovr1 = ds.GetRasterBand(1).GetOverview(1).ReadRaster(
        0, 0, int(width / 4), 1)
    data = struct.unpack('B' * (int(width / 4)), line0_ovr1)
    assert abs(data[0] - 0) <= 10
    line1023_bis = ds.GetRasterBand(1).ReadRaster(0, height - 1, width, 1)
    assert line1023_bis != line0 and line1023 == line1023_bis
    line0_bis = ds.GetRasterBand(1).ReadRaster(0, 0, width, 1)
    assert line0 == line0_bis
    line255_ovr1 = ds.GetRasterBand(1).GetOverview(1).ReadRaster(
        0,
        int(height / 4) - 1, int(width / 4), 1)
    data = struct.unpack('B' * int(width / 4), line255_ovr1)
    assert abs(data[0] - 255) <= 10
    line0_bis = ds.GetRasterBand(1).ReadRaster(0, 0, width, 1)
    assert line0 == line0_bis
    line0_ovr1_bis = ds.GetRasterBand(1).GetOverview(1).ReadRaster(
        0, 0, int(width / 4), 1)
    assert line0_ovr1 == line0_ovr1_bis
    line255_ovr1_bis = ds.GetRasterBand(1).GetOverview(1).ReadRaster(
        0,
        int(height / 4) - 1, int(width / 4), 1)
    assert line255_ovr1 == line255_ovr1_bis

    gdal.SetCacheMax(oldSize)

    ds = None
    gdal.Unlink('/vsimem/jpeg_18.jpg')
Code example #7
    def gdal_config_options(self, cmd=''):
        extra_args = []

        if 'GDAL_CACHEMAX' not in cmd:
            value = gdal.GetCacheMax()
            extra_args.extend(('--config', 'GDAL_CACHEMAX', str(value)))

        for key in ('CPL_DEBUG', 'GDAL_SKIP', 'GDAL_DATA',
                    'GDAL_DRIVER_PATH', 'OGR_DRIVER_PATH'):
            if key not in cmd:
                value = gdal.GetConfigOption(key, None)
                if value:
                    extra_args.extend(('--config', key, '"%s"' % value))

        return extra_args
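The helper above propagates the calling process's cache limit and selected config options to a spawned GDAL command line. A hypothetical usage sketch (the surrounding class and the exact command runner are assumptions, not shown in the snippet):

import subprocess

cmd = ['gdalwarp', 'src.tif', 'dst.tif']
# splice the inherited settings right after the program name
extra = self.gdal_config_options(' '.join(cmd))
# e.g. extra == ['--config', 'GDAL_CACHEMAX', '134217728', ...]
subprocess.check_call(cmd[:1] + extra + cmd[1:])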
Code example #8
File: basic_test.py Project: youngpm/gdal
def basic_test_7_internal():

    with pytest.raises(Exception):
        gdal.Open('non_existing_ds', gdal.GA_ReadOnly)

    # Special case: we should still be able to get the error message
    # until we call a new GDAL function
    assert matches_non_existing_error_msg(gdal.GetLastErrorMsg()), ('did not get expected error message, got %s' % gdal.GetLastErrorMsg())

    assert gdal.GetLastErrorType() != 0, 'did not get expected error type'

    # Should issue an implicit CPLErrorReset()
    gdal.GetCacheMax()

    assert gdal.GetLastErrorType() == 0, 'got unexpected error type'
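The pytest.raises() check in this variant only works when the bindings are in exception mode, since gdal.Open() otherwise returns None instead of raising. A one-line sketch of the setup that is assumed here but not shown:

from osgeo import gdal

gdal.UseExceptions()  # make gdal.Open() raise instead of returning None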
Code example #9
def jpeg_18():
    height = 1024
    width = 1024
    src_ds = gdal.GetDriverByName('GTiff').Create('/vsimem/jpeg_18.tif', width,
                                                  height, 1)
    for i in range(height):
        data = struct.pack('B' * 1, int(i / (height / 256)))
        src_ds.WriteRaster(0, i, width, 1, data, 1, 1)

    ds = gdal.GetDriverByName('JPEG').CreateCopy('/vsimem/jpeg_18.jpg',
                                                 src_ds,
                                                 options=['QUALITY=99'])
    src_ds = None
    gdal.Unlink('/vsimem/jpeg_18.tif')

    oldSize = gdal.GetCacheMax()
    gdal.SetCacheMax(0)

    line0 = ds.GetRasterBand(1).ReadRaster(0, 0, width, 1)
    data = struct.unpack('B' * width, line0)
    if abs(data[0] - 0) > 10:
        return 'fail'
    line1023 = ds.GetRasterBand(1).ReadRaster(0, height - 1, width, 1)
    data = struct.unpack('B' * width, line1023)
    if abs(data[0] - 255) > 10:
        return 'fail'
    line0_ovr1 = ds.GetRasterBand(1).GetOverview(1).ReadRaster(
        0, 0, int(width / 4), 1)
    data = struct.unpack('B' * (int(width / 4)), line0_ovr1)
    if abs(data[0] - 0) > 10:
        return 'fail'
    line1023_bis = ds.GetRasterBand(1).ReadRaster(0, height - 1, width, 1)
    if line1023_bis == line0 or line1023 != line1023_bis:
        gdaltest.post_reason('fail')
        return 'fail'
    line0_bis = ds.GetRasterBand(1).ReadRaster(0, 0, width, 1)
    if line0 != line0_bis:
        gdaltest.post_reason('fail')
        return 'fail'
    line255_ovr1 = ds.GetRasterBand(1).GetOverview(1).ReadRaster(
        0,
        int(height / 4) - 1, int(width / 4), 1)
    data = struct.unpack('B' * int(width / 4), line255_ovr1)
    if abs(data[0] - 255) > 10:
        return 'fail'
    line0_bis = ds.GetRasterBand(1).ReadRaster(0, 0, width, 1)
    if line0 != line0_bis:
        gdaltest.post_reason('fail')
        return 'fail'
    line0_ovr1_bis = ds.GetRasterBand(1).GetOverview(1).ReadRaster(
        0, 0, int(width / 4), 1)
    if line0_ovr1 != line0_ovr1_bis:
        gdaltest.post_reason('fail')
        return 'fail'
    line255_ovr1_bis = ds.GetRasterBand(1).GetOverview(1).ReadRaster(
        0,
        int(height / 4) - 1, int(width / 4), 1)
    if line255_ovr1 != line255_ovr1_bis:
        gdaltest.post_reason('fail')
        return 'fail'

    gdal.SetCacheMax(oldSize)

    ds = None
    gdal.Unlink('/vsimem/jpeg_18.jpg')

    return 'success'
Code example #10
File: mask.py Project: visr/gdal
def test_mask_14():

    src_ds = gdal.Open('data/byte.tif')

    assert src_ds is not None, 'Failed to open test dataset.'

    drv = gdal.GetDriverByName('GTiff')
    with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK_TO_8BIT', 'FALSE'):
        ds = drv.CreateCopy('tmp/byte_with_mask.tif', src_ds)
    src_ds = None

    # The only flag value supported for internal mask is GMF_PER_DATASET
    with gdaltest.error_handler():
        with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK', 'YES'):
            ret = ds.CreateMaskBand(0)
    assert ret != 0, 'Error expected'

    with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK', 'YES'):
        ret = ds.CreateMaskBand(gdal.GMF_PER_DATASET)
    assert ret == 0, 'Creation failed'

    cs = ds.GetRasterBand(1).GetMaskBand().Checksum()
    assert cs == 0, 'Got wrong checksum for the mask (1)'

    ds.GetRasterBand(1).GetMaskBand().Fill(1)

    cs = ds.GetRasterBand(1).GetMaskBand().Checksum()
    assert cs == 400, 'Got wrong checksum for the mask (2)'

    # This TIFF dataset already has an internal mask band
    with gdaltest.error_handler():
        with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK', 'YES'):
            ret = ds.CreateMaskBand(gdal.GMF_PER_DATASET)
    assert ret != 0, 'Error expected'

    # This TIFF dataset already has an internal mask band
    with gdaltest.error_handler():
        with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK', 'YES'):
            ret = ds.GetRasterBand(1).CreateMaskBand(gdal.GMF_PER_DATASET)
    assert ret != 0, 'Error expected'

    ds = None

    # tmp/byte_with_mask.tif.msk should not exist (the mask is internal)
    with pytest.raises(OSError):
        os.stat('tmp/byte_with_mask.tif.msk')

    with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK_TO_8BIT', 'FALSE'):
        ds = gdal.Open('tmp/byte_with_mask.tif')

        assert ds.GetRasterBand(1).GetMaskFlags() == gdal.GMF_PER_DATASET, \
          'wrong mask flags'

    cs = ds.GetRasterBand(1).GetMaskBand().Checksum()
    assert cs == 400, 'Got wrong checksum for the mask (3)'

    # Test fix for #5884
    old_val = gdal.GetCacheMax()
    gdal.SetCacheMax(0)
    with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK', 'YES'):
        out_ds = drv.CreateCopy('/vsimem/byte_with_mask.tif',
                                ds,
                                options=['COMPRESS=JPEG'])
    gdal.SetCacheMax(old_val)
    assert out_ds.GetRasterBand(1).Checksum() != 0
    cs = ds.GetRasterBand(1).GetMaskBand().Checksum()
    assert cs == 400, 'Got wrong checksum for the mask (4)'
    out_ds = None
    drv.Delete('/vsimem/byte_with_mask.tif')

    ds = None

    drv.Delete('tmp/byte_with_mask.tif')
Code example #11
def _create_thumbnail(red_file,
                      green_file,
                      blue_file,
                      output_path,
                      x_constraint=None,
                      nodata=-999,
                      work_dir=None,
                      overwrite=True):
    """
    Create JPEG thumbnail image using individual R, G, B images.

    This method comes from the old ULA codebase.

    :param red_file: red band data file
    :param green_file: green band data file
    :param blue_file: blue band data file
    :param output_path: thumbnail file to write to.
    :param x_constraint: thumbnail width (if not full resolution)
    :param nodata: null/fill data value
    :param work_dir: temp/work directory to use.
    :param overwrite: overwrite existing thumbnail?

    Thumbnail height is adjusted automatically to match the aspect ratio
    of the input images.

    """
    nodata = int(nodata)

    # GDAL calls need absolute paths.
    thumbnail_path = pathlib.Path(output_path).absolute()

    if thumbnail_path.exists() and not overwrite:
        _LOG.warning('File already exists. Skipping creation of %s',
                     thumbnail_path)
        return None, None, None

    # thumbnail_image = os.path.abspath(thumbnail_image)

    out_directory = str(thumbnail_path.parent)
    work_dir = os.path.abspath(work_dir) if work_dir else tempfile.mkdtemp(
        prefix='.thumb-tmp', dir=out_directory)
    try:
        # working files
        file_to = os.path.join(work_dir, 'rgb.vrt')
        warp_to_file = os.path.join(work_dir, 'rgb-warped.vrt')
        outtif = os.path.join(work_dir, 'thumbnail.tif')

        # Build the RGB Virtual Raster at full resolution
        run_command([
            "gdalbuildvrt", "-overwrite", "-separate", file_to,
            str(red_file),
            str(green_file),
            str(blue_file)
        ], work_dir)
        assert os.path.exists(file_to), "VRT must exist"

        # Determine the pixel scaling to get the correct width thumbnail
        vrt = gdal.Open(file_to)
        intransform = vrt.GetGeoTransform()
        inpixelx = intransform[1]
        # inpixely = intransform[5]
        inrows = vrt.RasterYSize
        incols = vrt.RasterXSize

        # If a specific resolution is asked for.
        if x_constraint:
            outresx = inpixelx * incols / x_constraint
            _LOG.info('Input pixel res %r, output pixel res %r', inpixelx,
                      outresx)

            outrows = int(
                math.ceil((float(inrows) / float(incols)) * x_constraint))

            run_command([
                "gdalwarp", "--config", "GDAL_CACHEMAX",
                str(GDAL_CACHE_MAX_MB), "-of", "VRT", "-tr",
                str(outresx),
                str(outresx), "-r", "near", "-overwrite", file_to, warp_to_file
            ], work_dir)
        else:
            # Otherwise use a full resolution browse image.
            outrows = inrows
            x_constraint = incols
            warp_to_file = file_to
            outresx = inpixelx

        _LOG.debug('Current GDAL cache max %rMB. Setting to %rMB',
                   gdal.GetCacheMax() / 1024 / 1024, GDAL_CACHE_MAX_MB)
        gdal.SetCacheMax(GDAL_CACHE_MAX_MB * 1024 * 1024)

        # Open VRT file to array
        vrt = gdal.Open(warp_to_file)
        driver = gdal.GetDriverByName("GTiff")
        outdataset = driver.Create(outtif, x_constraint, outrows, 3,
                                   gdalconst.GDT_Byte)

        # Loop through bands and apply Scale and Offset
        for band_number in (1, 2, 3):
            band = vrt.GetRasterBand(band_number)

            scale, offset = _calculate_scale_offset(nodata, band)

            # Apply gain and offset
            outdataset.GetRasterBand(band_number).WriteArray(
                (numpy.ma.masked_less_equal(band.ReadAsArray(), nodata) *
                 scale) + offset)
            _LOG.debug('Scale %r, offset %r', scale, offset)

        # Must close datasets to flush to disk.
        # noinspection PyUnusedLocal
        outdataset = None
        # noinspection PyUnusedLocal
        vrt = None

        # GDAL Create doesn't support JPEG so we need to make a copy of the GeoTIFF
        run_command([
            "gdal_translate", "--config", "GDAL_CACHEMAX",
            str(GDAL_CACHE_MAX_MB), "-of", "JPEG", outtif,
            str(thumbnail_path)
        ], work_dir)

        _LOG.debug('Cleaning work files')
    finally:
        # Clean up work files
        if os.path.exists(work_dir):
            shutil.rmtree(work_dir)

    # Newer versions of GDAL create aux files due to the histogram. Clean them up.
    for f in (red_file, blue_file, green_file):
        f = pathlib.Path(f)
        aux_file = f.with_name(f.name + '.aux.xml')
        if aux_file.exists():
            _LOG.info('Cleaning aux: %s', aux_file)
            os.remove(str(aux_file.absolute()))

    return x_constraint, outrows, outresx
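_calculate_scale_offset is referenced above but not shown. A hypothetical sketch of what such a helper could compute, consistent with how its result is applied ((masked array * scale) + offset written as Byte): a linear stretch of the valid pixel range onto 0..255. The real ULA implementation may differ:

import numpy

def _calculate_scale_offset(nodata, band):
    # Hypothetical: stretch valid pixel values linearly onto 0..255.
    arr = numpy.ma.masked_less_equal(band.ReadAsArray(), nodata)
    lo, hi = float(arr.min()), float(arr.max())
    if hi == lo:
        return 1.0, 0.0
    scale = 255.0 / (hi - lo)
    offset = -lo * scale
    return scale, offset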
Code example #12
def genTiles(task):
    time_start = datetime.now()
    try:
        cpu_count = 1
        if thread_count is None:
            cpu_count = multiprocessing.cpu_count()
        else:
            cpu_count = thread_count

        s_file = None
        if source_file.endswith(".zip"):
            src = source_file.replace("\\", "/")
            fileinfo = QFileInfo(src)
            filename = fileinfo.completeBaseName()
            s_file = [src, filename]
        else:
            return [0, time_start]

        #cpu_count = min(6, cpu_count)
        #if cpu_count % 2 == 0:
        #    cpu_count = int(cpu_count / 2)

        # converted_source = os.path.join(source_folder, "converted").replace("\\","/")

        QgsMessageLog.logMessage(
            'Started creating dataset from {name}'.format(name=s_file[1]),
            CATEGORY, Qgis.Info)

        source_folder = os.path.dirname(source_file)

        org_vrt = os.path.join(source_folder,
                               "OrgSource.vrt").replace("\\", "/")
        if os.path.isfile(org_vrt):
            os.remove(org_vrt)

        org_files = []
        c_file = '/vsizip/{archive}/{file}'.format(archive=s_file[0],
                                                   file=source_file_name)
        QgsMessageLog.logMessage(c_file, CATEGORY, Qgis.Info)
        org_files.append(c_file)

        ds = gdal.BuildVRT(org_vrt,
                           org_files,
                           resolution="highest",
                           resampleAlg="bilinear")
        ds.FlushCache()
        ds = None
        sleep(0.05)

        QgsMessageLog.logMessage('Created original vrt', CATEGORY, Qgis.Info)

        vrt = os.path.join(source_folder, "Source.vrt").replace("\\", "/")
        if os.path.isfile(vrt):
            os.remove(vrt)

        pds = gdal.Open(org_vrt)
        pds = gdal.Warp(vrt, pds, dstSRS="EPSG:3857")
        pds = None

        QgsMessageLog.logMessage('Created reprojected vrt', CATEGORY,
                                 Qgis.Info)

        terrarium_tile = os.path.join(source_folder,
                                      "TerrariumSource.vrt").replace(
                                          "\\", "/")
        if os.path.isfile(terrarium_tile):
            os.remove(terrarium_tile)

        statistics = []

        task.setProgress(0)

        ress = calculateStat(c_file)
        sleep(0.05)
        if ress is not None:
            statistics.append(ress)
            task.setProgress(12)
        else:
            return [0, time_start]

        ress = None

        org_files = None

        # Find minV and maxV from statistics

        minV = None
        maxV = None

        for stat in statistics:
            if stat.minV is not None and stat.maxV is not None:
                if minV is None:
                    minV = stat.minV
                elif stat.minV < minV:
                    minV = stat.minV

                if maxV is None:
                    maxV = stat.maxV
                elif stat.maxV > maxV:
                    maxV = stat.maxV

        if minV is None or maxV is None:
            QgsMessageLog.logMessage(
                'Error: Minimum and maximum height are None', CATEGORY,
                Qgis.Info)
            return None
        else:
            QgsMessageLog.logMessage(
                'Minimum and maximum height are {minv} and {maxv}'.format(
                    minv=minV, maxv=maxV), CATEGORY, Qgis.Info)
        statistics = None

        color_ramp_file = os.path.join(source_folder,
                                       "TerrariumSource.txt").replace(
                                           "\\", "/")
        if os.path.isfile(color_ramp_file):
            os.remove(color_ramp_file)
        color_ramp = []
        altitudes = []

        # Create color ramp
        for meters in range(minV, maxV):
            for fraction in range(0, 256):
                altitudes.append(meters + (fraction / 256))

        exx = concurrent.futures.ThreadPoolExecutor(max_workers=cpu_count)
        ff = exx.map(createRamp, repeat(convert_feet_to_meters), altitudes)
        for res in ff:
            if res is not None:
                color_ramp.append(res)

        QgsMessageLog.logMessage('Creating color ramp file', CATEGORY,
                                 Qgis.Info)

        #sorted_color_ramp = sorted(color_ramp, key=itemgetter(0))
        color_ramp.sort(key=lambda x: x.altitude)
        f = open(color_ramp_file, "w")
        rt = 0
        cf = len(color_ramp)
        for ramp in color_ramp:
            if manual_nodata_value is not None:
                if ramp.altitude == manual_nodata_value:
                    f.write('{altitude}\t0\t0\t0\t0\n'.format(
                        altitude=manual_nodata_value))
                    continue
            f.write('{altitude}\t{red}\t{green}\t{blue}\n'.format(
                altitude=ramp.altitude,
                red=ramp.red,
                green=ramp.green,
                blue=ramp.blue))
            rt += 1
            task.setProgress(max(0, min(int(((rt * 8) / cf) + 12), 100)))
        f.write('nv\t0\t0\t0\t0')
        f.close()
        sleep(0.05)

        QgsMessageLog.logMessage('Created color ramp file', CATEGORY,
                                 Qgis.Info)

        QgsMessageLog.logMessage(
            'Rendering vrt to terrarium format. This might take a while',
            CATEGORY, Qgis.Info)

        dst_ds = gdal.Open(vrt)

        ds = gdal.DEMProcessing(terrarium_tile,
                                dst_ds,
                                'color-relief',
                                colorFilename=color_ramp_file,
                                format="VRT",
                                addAlpha=True)
        sleep(0.05)

        dst_ds = None

        QgsMessageLog.logMessage('Created vrt in terrarium format', CATEGORY,
                                 Qgis.Info)

        input_file = terrarium_tile

        QgsMessageLog.logMessage('Input file is {inp}'.format(inp=input_file),
                                 CATEGORY, Qgis.Info)
        dst_ds = gdal.Open(input_file, gdal.GA_ReadOnly)

        QgsMessageLog.logMessage('Calculating bounds', CATEGORY, Qgis.Info)

        diff = (20037508.3427892439067364 +
                20037508.3427892439067364) / 2**zoom

        x_diff = diff
        y_diff = diff

        ulx, xres, xskew, uly, yskew, yres = dst_ds.GetGeoTransform()
        #lrx = ulx + (dst_ds.RasterXSize * xres)
        #lry = uly + (dst_ds.RasterYSize * yres)

        info = gdal.Info(dst_ds, format='json')

        ulx, uly = info['cornerCoordinates']['upperLeft'][0:2]
        lrx, lry = info['cornerCoordinates']['lowerRight'][0:2]

        wgs_minx, wgs_minz = info['wgs84Extent']['coordinates'][0][1][0:2]
        wgs_maxx, wgs_maxz = info['wgs84Extent']['coordinates'][0][3][0:2]

        QgsMessageLog.logMessage(
            'The dataset bounds are (in WGS84 [EPSG:4326]), minX: {minX}, minZ: {minZ}, maxX: {maxX}, maxZ: {maxZ}'
            .format(minX=str(wgs_minx),
                    minZ=str(wgs_minz),
                    maxX=str(wgs_maxx),
                    maxZ=str(wgs_maxz)), CATEGORY, Qgis.Info)
        QgsMessageLog.logMessage(
            'Use this info for following step F in Part two: Generating/using your dataset',
            CATEGORY, Qgis.Info)

        base_x = 0
        base_y = 0

        start_tile_x = base_x
        start_tile_y = base_y
        min_x = -20037508.3427892439067364
        max_y = 20037508.3427892439067364

        x_tiles = 0
        y_tiles = 0

        # Find min lot
        lon = min_x
        while lon <= 20037508.3427892439067364:
            if lon >= ulx:
                break
            start_tile_x += 1
            min_x = lon
            lon += x_diff

        # Find max lat
        lat = max_y
        while lat >= -20037508.3427892550826073:
            if lat <= uly:
                break
            start_tile_y += 1
            max_y = lat
            lat -= y_diff

        # Find how many lon tiles to make
        lon = min_x
        while lon < lrx:
            x_tiles += 1
            lon += x_diff

        # Find how many lat tiles to make
        lat = max_y
        while lat >= lry:
            y_tiles += 1
            lat -= y_diff

        if start_tile_x > 0:
            start_tile_x -= 1
        if start_tile_y > 0:
            start_tile_y -= 1

        QgsMessageLog.logMessage(
            'Start tile: {tx} ({mlon}), {ty} ({mxlat}), Tiles to generate: {xt} (Width: {xtx} Height: {xty})'
            .format(tx=start_tile_x,
                    mlon=min_x,
                    ty=start_tile_y,
                    mxlat=max_y,
                    xt=((x_tiles + 1) * (y_tiles + 1)),
                    xtx=x_tiles,
                    xty=y_tiles), CATEGORY, Qgis.Info)

        QgsMessageLog.logMessage('Creating output folders', CATEGORY,
                                 Qgis.Info)

        mx = start_tile_x + x_tiles
        zoom_folder = ''
        temp_folder = os.path.join(source_folder, 'TEMP')
        if not ftp_upload:
            zoom_folder = os.path.join(output_directory,
                                       str(zoom)).replace("\\", "/")
        else:
            if os.path.isdir(temp_folder):
                shutil.rmtree(temp_folder)
            os.mkdir(temp_folder)
            zoom_folder = os.path.join(temp_folder,
                                       str(zoom)).replace("\\", "/")
        if not os.path.isdir(zoom_folder):
            os.mkdir(zoom_folder)
        for x in range(start_tile_x, mx):
            folderx = os.path.join(zoom_folder, str(x)).replace("\\", "/")
            if not os.path.isdir(folderx):
                os.mkdir(folderx)

        zip_file = os.path.join(zoom_folder,
                                'RenderedDataset.zip').replace("\\", "/")
        rd_file = None
        if ftp_upload:
            if ftp_one_file:
                rd_file = zipfile.ZipFile(zip_file, 'w', zipfile.ZIP_DEFLATED)
            else:
                QgsMessageLog.logMessage(
                    'Creating output folders on target ftp server', CATEGORY,
                    Qgis.Info)
                ftp = None
                if ftp_s:
                    ftp = FTP_TLS()
                else:
                    ftp = FTP()
                ftp.connect(ftp_upload_url, ftp_upload_port)
                if ftp_user is None or ftp_password is None:
                    ftp.login()
                else:
                    ftp.login(user=ftp_user, passwd=ftp_password)

                if ftp_upload_folder is not None:
                    ftp.cwd(ftp_upload_folder)

                if not isFTPDir(ftp, str(zoom)):
                    ftp.mkd(str(zoom))

                ftp.cwd(str(zoom))

                for x in range(start_tile_x, mx):
                    ftp.mkd(str(x))

                ftp.quit()

        sleep(0.01)

        QgsMessageLog.logMessage('Created {ct} folders'.format(ct=x_tiles),
                                 CATEGORY, Qgis.Info)

        tiled = []

        # Tile dataset
        sub_min_x = min_x
        sub_max_x = min_x + x_diff
        for x in range(start_tile_x, start_tile_x + x_tiles):
            sub_min_y = max_y - y_diff
            sub_max_y = max_y
            for y in range(start_tile_y, start_tile_y + y_tiles):
                tiled.append(
                    Tile(dst_ds,
                         str(x),
                         str(y),
                         sub_min_x,
                         sub_max_y,
                         sub_max_x,
                         sub_min_y,
                         ulx,
                         xres,
                         uly,
                         yres,
                         querysize=1024))
                sub_min_y -= y_diff
                sub_max_y -= y_diff
            sub_min_x += x_diff
            sub_max_x += x_diff

        job = Job(zoom_folder, str(zoom), input_file, getBands(dst_ds),
                  resampling_algorithm, ftp_upload, ftp_one_file, ftp_s,
                  ftp_upload_url, ftp_upload_port, ftp_upload_folder, ftp_user,
                  ftp_password)

        dst_ds = None

        # Tile the dataset

        realtiles = 0

        if cpu_count == 1:
            QgsMessageLog.logMessage(
                'Started tiling vrt in single-thread mode', CATEGORY,
                Qgis.Info)
            rt = 0
            cf = len(tiled)
            for tile in tiled:
                res = tileVrt(job, tile)
                if res is not None:
                    if job.ftpUpload and job.ftpOnefile:
                        addToZIP(job, res, rd_file)
                    realtiles += 1
                rt += 1
                if job.ftpUpload and job.ftpOnefile:
                    task.setProgress(
                        max(0, min(int(((rt * 40) / cf) + 20), 100)))
                else:
                    task.setProgress(
                        max(0, min(int(((rt * 80) / cf) + 20), 100)))

            if getattr(localThread, 'ds', None):
                del localThread.ds

        else:
            QgsMessageLog.logMessage(
                'Started tiling vrt in multi-thread mode with {count} threads'.
                format(count=cpu_count), CATEGORY, Qgis.Info)

            gdal_cache_max = gdal.GetCacheMax()
            gdal_cache_max_per_process = max(
                1024 * 1024, math.floor(gdal_cache_max / cpu_count))
            setCacheMax(gdal_cache_max_per_process)

            tpe = concurrent.futures.ThreadPoolExecutor(max_workers=cpu_count)
            tm = tpe.map(tileVrt, repeat(job), tiled)
            rt = 0
            cf = len(tiled)
            realtiles = 0
            for res in tm:
                if res is not None:
                    if job.ftpUpload and job.ftpOnefile:
                        addToZIP(job, res, rd_file)
                    realtiles += 1
                rt += 1
                if job.ftpUpload and job.ftpOnefile:
                    task.setProgress(
                        max(0, min(int(((rt * 40) / cf) + 20), 100)))
                else:
                    task.setProgress(
                        max(0, min(int(((rt * 80) / cf) + 20), 100)))

            setCacheMax(gdal_cache_max)

        if job.ftpUpload and job.ftpOnefile:
            rd_file.close()
            totalSize = os.path.getsize(zip_file)
            QgsMessageLog.logMessage(
                'Starting upload of rendered archive ({size} GB) containing {count} tiles'
                .format(count=realtiles,
                        size=str(round(totalSize * 9.3132257461548E-10,
                                       4))), CATEGORY, Qgis.Info)

            ftp = None
            if job.ftpS:
                ftp = FTP_TLS()
            else:
                ftp = FTP()
            ftp.connect(job.ftpUrl, job.ftpPort)

            if job.ftpUser is None or job.ftpPassword is None:
                ftp.login()
            else:
                ftp.login(user=job.ftpUser, passwd=job.ftpPassword)

            if job.ftpFolder is not None:
                ftp.cwd(job.ftpFolder)

            ftapr = FTPArchiveProgress(int(totalSize), task)

            with open(zip_file, 'rb') as rdf_file:
                ftp.storbinary('STOR RenderedDataset.zip',
                               rdf_file,
                               blocksize=1024,
                               callback=ftapr.updateProgress)

            ftp.quit()

            QgsMessageLog.logMessage(
                'Uploaded rendered dataset archive RenderedDataset.zip with {count} tiles. You can now unzip it.'
                .format(count=realtiles), CATEGORY, Qgis.Info)
        elif job.ftpUpload:
            QgsMessageLog.logMessage(
                'Tiled and uploaded dataset with {count} tiles to ftp server'.
                format(count=realtiles), CATEGORY, Qgis.Info)
        else:
            QgsMessageLog.logMessage(
                'Tiled dataset with {count} tiles'.format(count=realtiles),
                CATEGORY, Qgis.Info)

        # Clean up
        if cleanup:
            os.remove(org_vrt)
            os.remove(vrt)
            os.remove(color_ramp_file)
            os.remove(terrarium_tile)

            if job.ftpUpload:
                shutil.rmtree(temp_folder)

            QgsMessageLog.logMessage('Cleaned up temp files', CATEGORY,
                                     Qgis.Info)

        return [realtiles, time_start]
    except Exception as e:
        QgsMessageLog.logMessage('Error: ' + str(e), CATEGORY, Qgis.Info)
        return None
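The multi-thread branch above divides the global block cache between workers (gdal_cache_max / cpu_count, floored at 1 MiB each) and restores the original limit afterwards. A standalone sketch of that arithmetic, assuming setCacheMax in the snippet is a thin wrapper over gdal.SetCacheMax:

import math
from osgeo import gdal

cpu_count = 4
gdal_cache_max = gdal.GetCacheMax()
# give each worker an equal share, but never less than 1 MiB
per_worker = max(1024 * 1024, math.floor(gdal_cache_max / cpu_count))
gdal.SetCacheMax(per_worker)
try:
    pass  # run the tiling workers here
finally:
    gdal.SetCacheMax(gdal_cache_max)  # restore the global limit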
Code example #13
    def __init__(self, parent: QtWidgets.QWidget=None):
        super().__init__(parent)

        self.layout = QtWidgets.QVBoxLayout()
        self.setLayout(self.layout)

        # Top

        self.top_layout = QtWidgets.QHBoxLayout()

        self.layout.addLayout(self.top_layout)
        self.version_groupbox = QtWidgets.QGroupBox(self)
        self.version_groupbox.setTitle("Version")
        self.top_layout.addWidget(self.version_groupbox)
        self.version_layout = QtWidgets.QVBoxLayout()
        self.version_groupbox.setLayout(self.version_layout)
        self.version_release = QtWidgets.QLabel("Release: %s" % gdal.VersionInfo('RELEASE_NAME'))
        self.version_layout.addWidget(self.version_release)
        self.version_date = QtWidgets.QLabel("Date: %s" % gdal.VersionInfo('RELEASE_DATE'))
        self.version_layout.addWidget(self.version_date)

        self.cache_groupbox = QtWidgets.QGroupBox(self)
        self.cache_groupbox.setTitle("Cache")
        self.top_layout.addWidget(self.cache_groupbox)
        self.cache_layout = QtWidgets.QVBoxLayout()
        self.cache_groupbox.setLayout(self.cache_layout)
        self.cache_max = QtWidgets.QLabel("Max: %.1f MB" % (gdal.GetCacheMax() / (1024 * 1024),))
        self.cache_layout.addWidget(self.cache_max)
        self.cache_used = QtWidgets.QLabel("Used: %.1f MB" % (gdal.GetCacheUsed() / (1024 * 1024),))
        self.cache_layout.addWidget(self.cache_used)

        # Bottom

        self.bottom_layout = QtWidgets.QHBoxLayout()
        self.layout.addLayout(self.bottom_layout)

        self.drivers_groupbox = QtWidgets.QGroupBox(self)
        self.drivers_groupbox.setTitle("Drivers")
        self.bottom_layout.addWidget(self.drivers_groupbox)
        self.drivers_layout = QtWidgets.QVBoxLayout()
        self.drivers_groupbox.setLayout(self.drivers_layout)
        self.drivers_count = QtWidgets.QLabel("Number of drivers: %d" % (gdal.GetDriverCount(),))
        self.drivers_layout.addWidget(self.drivers_count)

        self.table = QtWidgets.QTableWidget(self)
        self.table.setObjectName("GdalDriversTable")
        self.table.setColumnCount(8)
        self.table.setRowCount(0)
        item = QtWidgets.QTableWidgetItem()
        item.setText("Short Name")
        self.table.setHorizontalHeaderItem(0, item)
        item = QtWidgets.QTableWidgetItem()
        item.setText("Long Name")
        self.table.setHorizontalHeaderItem(1, item)
        item = QtWidgets.QTableWidgetItem()
        item.setText("Help Page")
        self.table.setHorizontalHeaderItem(2, item)
        item = QtWidgets.QTableWidgetItem()
        item.setText("Mime Type")
        self.table.setHorizontalHeaderItem(3, item)
        item = QtWidgets.QTableWidgetItem()
        item.setText("Extensions")
        self.table.setHorizontalHeaderItem(4, item)
        item = QtWidgets.QTableWidgetItem()
        item.setText("Data Types")
        self.table.setHorizontalHeaderItem(5, item)
        item = QtWidgets.QTableWidgetItem()
        item.setText("Creation Options")
        self.table.setHorizontalHeaderItem(6, item)
        item = QtWidgets.QTableWidgetItem()
        item.setText("Metadata")
        self.table.setHorizontalHeaderItem(7, item)
        self.table.horizontalHeader().setStretchLastSection(True)
        self.table.horizontalHeader().setVisible(True)
        self.table.verticalHeader().setVisible(False)
        self.table.setRowCount(gdal.GetDriverCount())
        for row in range(gdal.GetDriverCount()):
            driver = gdal.GetDriver(row)
            self.table.setItem(row, 0, QtWidgets.QTableWidgetItem(driver.ShortName))
            self.table.setItem(row, 1, QtWidgets.QTableWidgetItem(driver.LongName))
            self.table.setItem(row, 2, QtWidgets.QTableWidgetItem(driver.HelpTopic))

            metadata = driver.GetMetadata()
            if metadata:
                self.table.setItem(row, 3, QtWidgets.QTableWidgetItem(str(metadata.pop(gdal.DMD_MIMETYPE, ''))))
                self.table.setItem(row, 4, QtWidgets.QTableWidgetItem(str(metadata.pop(gdal.DMD_EXTENSION, ''))))
                self.table.setItem(row, 5,
                                   QtWidgets.QTableWidgetItem(str(metadata.pop(gdal.DMD_CREATIONDATATYPES, ''))))

                full_data = metadata.pop(gdal.DMD_CREATIONOPTIONLIST, '')
                if full_data:
                    data = full_data[:10] + "[..]"
                else:
                    data = full_data
                table_item = QtWidgets.QTableWidgetItem(data)
                table_item.setToolTip(full_data)
                self.table.setItem(row, 6, table_item)

                metadata_list = ['%s=%s' % (k, v) for k, v in metadata.items()]
                metadata = ", ".join(metadata_list)[:10] + "[..]"
                table_item = QtWidgets.QTableWidgetItem(metadata)
                table_item.setToolTip('\n'.join(metadata_list))
                self.table.setItem(row, 7, table_item)

        self.table.horizontalHeader().resizeSections(QtWidgets.QHeaderView.ResizeToContents)
        self.table.setSortingEnabled(True)
        self.table.sortItems(0, QtCore.Qt.AscendingOrder)
        self.drivers_layout.addWidget(self.table)
Code example #14
    def __init__(self, arguments):
        """Constructor function - initialization"""
        try:
            subprocess.call(["gdalbuildvrt", "--help"])
        except:
            print("gdalbuildvrt is required to run gdal2cesium in multi-input mode")
            exit(1)

        self.stopped = False
        self.multi_suffix = ''
        self.input = None
        self.default_base_output = 'tiles'
        self.min_tile_tz = None
        self.inputs_data = {}
        self.inputs_files_or_vrt = []
        self.vrts = {}
        self.tminmax = None
        self.zoom_resolutions = {}
        self.tminz = None
        self.tmaxz = None

        gdal.AllRegister()
        self.mem_drv = gdal.GetDriverByName('MEM')
        self.geodetic = GlobalGeodetic()

        # Tile format
        self.tilesize = 64
        self.tileext = 'terrain'

        self.epsg4326 = "EPSG:4326"

        self.tilelayer = None

        self.scaledquery = True
        # How big should the query window be for scaling down
        # Later on reset according to the chosen resampling algorithm
        self.querysize = 4 * self.tilesize

        # pixel overlap between tiles according to the Cesium heightmap format
        self.extrapixels = 0

        # RUN THE ARGUMENT PARSER:
        self.optparse_init()
        self.options, self.args = self.parser.parse_args(args=arguments)
        self.options.srcnodata = None
        if not self.args:
            self.error("No input file specified")

        # POSTPROCESSING OF PARSED ARGUMENTS:
        # Workaround for old versions of GDAL
        try:
            if (self.options.verbose and self.options.resampling
                    == 'near') or gdal.TermProgress_nocb:
                pass
        except:
            self.error(
                "This version of GDAL is not supported. Please upgrade to 1.6+."
            )
            #,"You can try run crippled version of gdal2tiles with parameters: -v -r 'near'")

        self.inputs = [i for i in self.args]

        # Default values for not given options
        if self.options.output:
            self.output = self.options.output
        else:
            if len(self.inputs) > 0:
                self.multi_suffix = '_multi'
            self.output = os.path.join(
                self.default_base_output,
                os.path.basename(self.inputs[0]).split('.')[0] +
                self.multi_suffix)
            self.options.title = os.path.basename(self.inputs[0] +
                                                  self.multi_suffix)
        self.tmpoutput = os.path.join(self.output, 'tmp')

        # Supported options
        self.resampling = None

        if self.options.resampling == 'average':
            try:
                if gdal.RegenerateOverview:
                    pass
            except:
                self.error(
                    "'average' resampling algorithm is not available.",
                    "Please use -r 'near' argument or upgrade to newer version of GDAL."
                )
        elif self.options.resampling == 'near':
            self.resampling = gdal.GRA_NearestNeighbour
            self.querysize = self.tilesize
        elif self.options.resampling == 'bilinear':
            self.resampling = gdal.GRA_Bilinear
            self.querysize = self.tilesize * 2
        elif self.options.resampling == 'cubic':
            self.resampling = gdal.GRA_Cubic
        elif self.options.resampling == 'cubicspline':
            self.resampling = gdal.GRA_CubicSpline
        elif self.options.resampling == 'lanczos':
            self.resampling = gdal.GRA_Lanczos

        # User specified zoom levels
        self.user_tminz = None
        self.user_tmaxz = None
        if self.options.zoom:
            minmax = self.options.zoom.split('-', 1)
            minmax.extend([''])
            min, max = minmax[:2]
            self.user_tminz = int(min)
            if max:
                self.user_tmaxz = int(max)
            else:
                self.user_tmaxz = int(min)

        # Output the results
        if self.options.verbose:
            print("Options:", self.options)
            print("Input:", self.inputs[0] + self.multi_suffix)
            print("Output:", self.output)
            print("Cache: %s MB" % (gdal.GetCacheMax() / 1024 / 1024))
            print('')
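The zoom handling above accepts either a single level ('5') or a range ('5-12'). A standalone sketch of that parsing logic, extracted for clarity (parse_zoom is an illustrative name, not part of the original class):

def parse_zoom(value):
    # '5' -> (5, 5); '5-12' -> (5, 12)
    minmax = value.split('-', 1)
    minmax.extend([''])
    zmin, zmax = minmax[:2]
    return int(zmin), int(zmax) if zmax else int(zmin)

assert parse_zoom('5') == (5, 5)
assert parse_zoom('5-12') == (5, 12)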
Code example #15
def main(cmdargs):
    rank = 0  # leftover from the MPI version; in this single-process version the rank is simply 0
    hcr_rasters = cmdargs.input_rasters
    # check the input rasters and calculate the number of tiles (docs)
    hcr_datasets = [
        gdal.Open(rfile, gdal.GA_ReadOnly) for rfile in hcr_rasters
    ]
    hcr_raster_xsize = hcr_datasets[0].RasterXSize
    hcr_raster_ysize = hcr_datasets[0].RasterYSize

    if not np.all([ds.RasterXSize == hcr_raster_xsize for ds in hcr_datasets]):
        raise RuntimeError("Input rasters have different X dimensions!")
    if not np.all([ds.RasterYSize == hcr_raster_ysize for ds in hcr_datasets]):
        raise RuntimeError("Input rasters have different Y dimensions!")

    hcr_bands = [
        ds.GetRasterBand(i) for ds, i in zip(hcr_datasets, cmdargs.bands)
    ]

    cv_split = cmdargs.cv_split
    hcr_test_ds = None
    hcr_test_bd = None
    if cmdargs.test is not None:
        hcr_test_ds = gdal.Open(cmdargs.test, gdal.GA_ReadOnly)
        hcr_test_bd = hcr_test_ds.GetRasterBand(1)

    tile_xsize = cmdargs.doc_tile_size
    tile_ysize = cmdargs.doc_tile_size

    ndocs_batch = cmdargs.n_batch_docs

    # Set up dict of class codes within raster to common class codes/names across raster
    class_code2vocab = []
    for csvfname in cmdargs.class2vocab:
        arr = np.loadtxt(csvfname, dtype=int, delimiter=',')
        class_code2vocab.append({row[0]: row[1] for row in arr})
    class_errmat = None
    use_errmat = False
    if cmdargs.error_matrix is not None:
        class_errmat = []
        for csvfname in cmdargs.error_matrix:
            tmp = pd.read_csv(csvfname, index_col=0, header=0)
            tmp.columns = [int(val) for val in tmp.columns]
            class_errmat.append(tmp)
        use_errmat = True

    if cmdargs.N_factor is None:
        if class_errmat is not None:
            tmpval = np.min(
                [np.min(em.values[em.values != 0]) for em in class_errmat])
            tmpval = 10**(len(str(int(1. / tmpval))) + 1)
            N_factor = 1000 if 1000 > tmpval else tmpval
        else:
            N_factor = 1000
    else:
        N_factor = cmdargs.N_factor

    if cmdargs.doc_topic_prior is None:
        doc_topic_prior = 1. / cmdargs.n_topics
    else:
        doc_topic_prior = cmdargs.doc_topic_prior
    vocab = set(
        list(itertools.chain(*[c2v.values() for c2v in class_code2vocab])))
    if cmdargs.topic_word_prior is None:
        topic_word_prior = 1. / len(vocab)
    else:
        topic_word_prior = cmdargs.topic_word_prior

    vocab_creation = cmdargs.vocab_creation

    hcr = HarmonizeClassRasters(class_code2vocab,
                                class_errmat,
                                vocab_creation=vocab_creation,
                                n_components=cmdargs.n_topics,
                                max_iter=1000,
                                evaluate_every=1,
                                perp_tol=1e-1,
                                n_jobs=cmdargs.n_jobs,
                                batch_size=ndocs_batch,
                                doc_topic_prior=doc_topic_prior,
                                topic_word_prior=topic_word_prior)

    # Set up a look-up table (LUT) to store the occurrence counts, LDA scores,
    # topic prob. estimates of all the combinations of input class labels from
    # different rasters.
    #
    # MultiIndex: class legends of input rasters
    # Columns: occurrence counts, LDA scores, topic probs., primary topic ID
    # (the topic with the largest prob.)
    index_list = [set(c2v.values()) for c2v in class_code2vocab]
    hcr_lut_index = pd.MultiIndex.from_product(index_list)
    nrows = len(hcr_lut_index)
    ncols = 1 + 1 + 1 + 1 + cmdargs.n_topics
    prob_topic_colnames = [
        "prob_topic_{0:d}".format(i + 1) for i in range(cmdargs.n_topics)
    ]
    hcr_lut = pd.DataFrame(
        np.zeros((nrows, ncols)),
        index=hcr_lut_index,
        columns=["total_npix", "test_npix", "lda_score", "primary_topic_id"] +
        prob_topic_colnames)
    hcr_lut_values = hcr_lut.values.copy()

    ntiles_x = np.ceil(hcr_raster_xsize / tile_xsize).astype(int)
    ntiles_y = np.ceil(hcr_raster_ysize / tile_ysize).astype(int)
    dw_mat = np.zeros((ndocs_batch, len(hcr.vocab)))
    word_count = np.zeros(len(hcr.vocab))
    doc_idx = 0
    tmp = np.argmax([
        np.dtype(gdal_array.GDALTypeCodeToNumericTypeCode(
            bd.DataType)).itemsize for bd in hcr_bands
    ])
    img_dtype = gdal_array.GDALTypeCodeToNumericTypeCode(
        hcr_bands[tmp].DataType)
    nodata = np.iinfo(img_dtype).max

    #    prf = cProfile.Profile()
    #    prf.enable()
    ndigits = max(len(str(ntiles_x)), len(str(ntiles_y)))
    progress_tot = ntiles_x * ntiles_y
    progress_pct = 10
    progress_frc = int(progress_pct / 100. * progress_tot)
    if progress_frc == 0:
        progress_frc = 1
        progress_pct = int(progress_frc / float(progress_tot) * 100)

    progress_cnt = 0
    progress_npct = 0
    if cv_split is not None:
        logger.info(
            "Process {0:d}: Search non-empty tiles for cross-validation sampling ..."
            .format(rank))
        valid_tiles = []
        for iby in range(ntiles_y):
            for ibx in range(ntiles_x):
                xoff, yoff = tile_xsize * ibx, tile_ysize * iby
                win_xsize = tile_xsize if ibx < ntiles_x - 1 else hcr_raster_xsize - xoff
                win_ysize = tile_ysize if iby < ntiles_y - 1 else hcr_raster_ysize - yoff

                mb_img = [
                    bd.ReadAsArray(xoff, yoff, win_xsize,
                                   win_ysize).astype(img_dtype)
                    for bd in hcr_bands
                ]
                valid_flag = False
                for ib, img in enumerate(mb_img):
                    tmp = set(hcr.class_code2vocab[ib].index.values)
                    tmp_len = len(tmp)
                    if len(tmp - set(np.unique(img))) < tmp_len:
                        valid_flag = True
                        break
                if valid_flag:
                    valid_tiles.append(
                        np.ravel_multi_index((iby, ibx), (ntiles_y, ntiles_x)))
                progress_cnt += 1
                if progress_cnt % progress_frc == 0:
                    progress_npct += progress_pct
                    if progress_npct <= 100:
                        logger.info(
                            "Process {1:d}: Finish searching non-empty tiles {0:d}%"
                            .format(progress_npct, rank))
        logger.info(
            "Process {1:d}: Finish searching non-empty tiles, {0:d} non-empty tiles found"
            .format(len(valid_tiles), rank))

        # Random sampling for each CV test
        cv_test_tiles = []
        cv_hcr_list = []
        cv_dw_mat_list = [np.zeros((ndocs_batch, len(hcr.vocab)))] * cv_split
        cv_doc_idx_list = [0] * cv_split
        cv_score_sum_list = [0] * cv_split
        cv_perplexity_list = [0] * cv_split
        cv_size = int(len(valid_tiles) / cv_split)
        for i in range(cv_split):
            tmp = np.random.choice(valid_tiles, size=cv_size, replace=False)
            tmprow, tmpcol = np.unravel_index(tmp, (ntiles_y, ntiles_x))
            tmp = spspa.coo_matrix(
                (np.ones_like(tmprow, dtype=bool), (tmprow, tmpcol)),
                shape=(ntiles_y, ntiles_x),
                dtype=bool)
            cv_test_tiles.append(tmp.tocsr())
            cv_hcr_list.append(
                HarmonizeClassRasters(class_code2vocab,
                                      class_errmat,
                                      vocab_creation=vocab_creation,
                                      n_components=cmdargs.n_topics,
                                      max_iter=1000,
                                      evaluate_every=1,
                                      perp_tol=1e-1,
                                      n_jobs=cmdargs.n_jobs,
                                      batch_size=ndocs_batch,
                                      doc_topic_prior=doc_topic_prior,
                                      topic_word_prior=topic_word_prior))

    progress_cnt = 0
    progress_npct = 0
    logger.info(
        "Process {0:d}: Start building document-word matrix from input classification rasters ..."
        .format(rank))
    for iby in range(ntiles_y):
        for ibx in range(ntiles_x):
            xoff, yoff = tile_xsize * ibx, tile_ysize * iby
            win_xsize = tile_xsize if ibx < ntiles_x - 1 else hcr_raster_xsize - xoff
            win_ysize = tile_ysize if iby < ntiles_y - 1 else hcr_raster_ysize - yoff

            mb_img = [
                bd.ReadAsArray(xoff, yoff, win_xsize,
                               win_ysize).astype(img_dtype) for bd in hcr_bands
            ]
            # mask out invalid pixels
            img_mask = np.zeros_like(mb_img[0], dtype=bool)
            for ib, img in enumerate(mb_img):
                tmp_mask = np.ones_like(img, dtype=bool)
                for v in hcr.class_code2vocab[ib].index:
                    tmp_mask = np.logical_and(tmp_mask, img != v)
                img_mask = np.logical_or(img_mask, tmp_mask)

            tmp_mask = np.logical_not(img_mask)
            to_do_flag = tmp_mask.sum() > 0
            if to_do_flag:
                uq_idx, uq_cnt = np.unique(np.array([
                    hcr._translateArray(img[tmp_mask], hcr.class_code2vocab[i])
                    for i, img in enumerate(mb_img)
                ]),
                                           axis=1,
                                           return_counts=True)
                hcr_lut.loc[[*zip(*uq_idx.tolist())], "total_npix"] += uq_cnt

            if to_do_flag and hcr_test_bd is not None:
                test_mask = hcr_test_bd.ReadAsArray(xoff, yoff, win_xsize,
                                                    win_ysize)
                # img_mask = np.logical_or(img_mask, test_mask==1)
                tmp_mask = np.logical_not(
                    np.logical_or(img_mask, test_mask == 1))
                # to_do_flag = tmp_mask.sum() > 0
                if tmp_mask.sum() > 0:  # to_do_flag:
                    uq_idx, uq_cnt = np.unique(np.array([
                        hcr._translateArray(img[tmp_mask],
                                            hcr.class_code2vocab[i])
                        for i, img in enumerate(mb_img)
                    ]),
                                               axis=1,
                                               return_counts=True)
                    hcr_lut.loc[[*zip(*uq_idx.tolist())],
                                "test_npix"] += uq_cnt

            if to_do_flag:
                for img in mb_img:
                    img[img_mask] = nodata
                dw_mat[doc_idx, :] = hcr.genDocWordFromArray(
                    np.dstack(mb_img),
                    use_errmat=use_errmat,
                    N_factor=N_factor)
                if cv_split is not None:
                    for i in range(cv_split):
                        if not cv_test_tiles[i][iby, ibx]:
                            cv_dw_mat_list[i][cv_doc_idx_list[i], :] = dw_mat[
                                doc_idx, :]
                            cv_doc_idx_list[i] += 1
                doc_idx += 1

            if doc_idx == ndocs_batch:
                word_count += np.sum(dw_mat[0:doc_idx, :], axis=0)
                hcr.fitTopicModel(dw_mat[0:doc_idx, :])
                doc_idx = 0
            if cv_split is not None:
                for i in range(cv_split):
                    if cv_doc_idx_list[i] == ndocs_batch:
                        cv_hcr_list[i].fitTopicModel(
                            cv_dw_mat_list[i][0:cv_doc_idx_list[i], :])
                        cv_doc_idx_list[i] = 0

            progress_cnt += 1
            if progress_cnt % progress_frc == 0:
                progress_npct += progress_pct
                if progress_npct <= 100:
                    logger.info(
                        "Process {1:d}: Finish reading input rasters and building document-word matrix {0:d}%"
                        .format(progress_npct, rank))

#        if doc_idx > ndocs_batch-2:
#            break

    if doc_idx > 0:
        word_count += np.sum(dw_mat[0:doc_idx, :], axis=0)
        hcr.fitTopicModel(dw_mat[0:doc_idx, :])
        doc_idx = 0
    tmp = np.where(word_count == 0)[0]
    if len(tmp) > 0:
        logger.warning(
            "Process {0:d}: Some classes never appeared in the input rasters: \n"
            .format(rank) + str(np.array(list(hcr.vocab))[tmp]))
    if cv_split is not None:
        for i in range(cv_split):
            if cv_doc_idx_list[i] > 0:
                cv_hcr_list[i].fitTopicModel(
                    cv_dw_mat_list[i][0:cv_doc_idx_list[i], :])
                cv_doc_idx_list[i] = 0

#    prf.disable()
#    prf.print_stats(sort="time")

# The following only on master node, no parallelization.
    logger.info("Process {0:d}: Finish fitting LDA model ...".format(rank))

    # Estimate the topic probs. for all the combinations of the input class
    # labels and their LDA scores.
    logger.info(
        "Process {0:d}: Start building LUT of the topics and LDA score per combination of input class legends ..."
        .format(rank))
    #         prf = cProfile.Profile()
    #         prf.enable()
    dw_mat = np.array([
        hcr.genDocWordFromArray(np.array(idx)[np.newaxis, np.newaxis, :],
                                use_errmat=use_errmat,
                                N_factor=N_factor) for idx in hcr_lut.index
    ])
    dt_dist = hcr.estDocTopicDist(dw_mat)
    hcr_lut.loc[:, prob_topic_colnames] = dt_dist
    hcr_lut.loc[:, "primary_topic_id"] = np.argmax(dt_dist, axis=1) + 1
    for i, idx in enumerate(hcr_lut.index):
        hcr_lut.loc[idx, "lda_score"] = hcr.lda.score(dw_mat[[i], :])
#         prf.disable()
#         prf.print_stats(sort="time")
    logger.info(
        "Process {0:d}: Finish building LUT of the topics and LDA score per combination of input class legends ..."
        .format(rank))

    if hcr_test_bd is not None:
        # Calculate the perplexity and score over the test pixels
        logger.info(
            "Process {0:d}: Start estimating perplexity and score over test pixels ..."
            .format(rank))
        #         prf = cProfile.Profile()
        #         prf.enable()
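        # LDA perplexity: exp(-total log-likelihood / total word count),
        # where every test pixel contributes N_factor words.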
        score_sum = np.sum(hcr_lut["lda_score"] * hcr_lut["test_npix"])
        perplexity = np.exp(
            -1 * np.sum(hcr_lut["lda_score"] * hcr_lut["test_npix"]) /
            (N_factor * np.sum(hcr_lut["test_npix"])))
    else:
        score_sum = np.nan
        perplexity = np.nan

    if cv_split is not None:
        logger.info(
            "Process {0:d}: Start estimating perplexity and score over test tiles in the cross validation ..."
            .format(rank))
        progress_npct = 0
        progress_cnt = 0
        for iby in range(ntiles_y):
            for ibx in range(ntiles_x):
                if np.any(
                    [cv_test_tiles[i][iby, ibx] for i in range(cv_split)]):
                    xoff, yoff = tile_xsize * ibx, tile_ysize * iby
                    win_xsize = tile_xsize if ibx < ntiles_x - 1 else hcr_raster_xsize - xoff
                    win_ysize = tile_ysize if iby < ntiles_y - 1 else hcr_raster_ysize - yoff

                    mb_img = [
                        bd.ReadAsArray(xoff, yoff, win_xsize,
                                       win_ysize).astype(img_dtype)
                        for bd in hcr_bands
                    ]
                    # mask out invalid pixels
                    img_mask = np.zeros_like(mb_img[0], dtype=bool)  # np.bool alias removed in NumPy >= 1.24
                    for ib, img in enumerate(mb_img):
                        tmp_mask = np.ones_like(img, dtype=bool)
                        for v in hcr.class_code2vocab[ib].index:
                            tmp_mask = np.logical_and(tmp_mask, img != v)
                        img_mask = np.logical_or(img_mask, tmp_mask)

                    tmp_mask = np.logical_not(img_mask)
                    to_do_flag = tmp_mask.sum() > 0
                    if to_do_flag:
                        for img in mb_img:
                            img[img_mask] = nodata
                        dw_mat = hcr.genDocWordFromArray(np.dstack(mb_img),
                                                         use_errmat=use_errmat,
                                                         N_factor=N_factor)
                        for i in range(cv_split):
                            if cv_test_tiles[i][iby, ibx]:
                                cv_score_sum_list[i] += cv_hcr_list[
                                    i].lda.score(dw_mat[np.newaxis, :])
                                cv_perplexity_list[i] += np.log(
                                    cv_hcr_list[i].lda.perplexity(
                                        dw_mat[np.newaxis, :]))
                progress_cnt += 1
                if progress_cnt % progress_frc == 0:
                    progress_npct += progress_pct
                    if progress_npct <= 100:
                        logger.info(
                            "Process {1:d}: Finish inference of test tiles in cross validation {0:d}%"
                            .format(progress_npct, rank))
        logger.info(
            "Process {4:d}: doc_size = {0:d}, n_topics = {1:d}, n_factor = {2:d}, cv_size = {3:d}, cross validation report: "
            .format(cmdargs.doc_tile_size, cmdargs.n_topics, N_factor, cv_size,
                    rank))
        logger.info(
            "Process {0:d}: cv_seq, perplexity_cv, score_cv".format(rank))
        for i in range(cv_split):
            logger.info("Process {3:d}: {0:d}, {1:e}, {2:e}".format(
                i, np.exp(cv_perplexity_list[i] / cv_size),
                cv_score_sum_list[i] / cv_size, rank))
        logger.info("Process {3:d}: {0:s}, {1:e}, {2:e}".format(
            "mean", np.exp(np.mean(cv_perplexity_list) / cv_size),
            np.mean(cv_score_sum_list) / cv_size, rank))


#         prf.disable()
#         prf.print_stats(sort="time")

    logger.info(
        "Process {5:d}: doc_size = {2:d}, n_topics = {3:d}, n_factor = {4:d}, perplexity_test_pixels = {0:e}, score_test_pixels = {1:e}"
        .format(perplexity, score_sum, cmdargs.doc_tile_size, cmdargs.n_topics,
                N_factor, rank))

    joblib.dump(hcr.lda, cmdargs.out_model)
    logger.info("Process {1:d}: trained LDA model saved to {0:s}".format(
        cmdargs.out_model, rank))
    vocab_joblib = ".".join(
        cmdargs.out_model.split('.')[0:-1]) + "_vocab.joblib"
    joblib.dump(hcr.vocab, vocab_joblib)
    logger.info(
        "Process {1:d}: vocabulary list of trained LDA model saved to {0:s}".
        format(vocab_joblib, rank))

    if cmdargs.out_topic_word is not None:
        pd.DataFrame(hcr.getTopicWordDist(),
                     columns=hcr._dw.index).to_csv(cmdargs.out_topic_word)
        logger.info(
            "Process {1:d}: Topic-word distribution written to {0:s}".format(
                cmdargs.out_topic_word, rank))
    if cmdargs.out_lut is not None:
        hcr_lut.to_csv(cmdargs.out_lut)
        logger.info(
            "Process {1:d}: Look-up table of LDA model written to {0:s}".
            format(cmdargs.out_lut, rank))

    if cmdargs.out_class is not None:
        class_raster = cmdargs.out_class
        class_format = cmdargs.out_format
        prob_raster = cmdargs.out_prob
        prob_format = cmdargs.out_format

        class_nodata = 0
        prob_nodata = 0

        if hcr.lda.n_components < np.iinfo(np.int8).max:
            out_type = np.int8
        elif hcr.lda.n_components < np.iinfo(np.int16).max:
            out_type = np.int16
        elif hcr.lda.n_components < np.iinfo(np.int32).max:
            out_type = np.int32
        elif hcr.lda.n_components < np.iinfo(np.int64).max:
            out_type = np.int64
        else:
            out_type = np.float64  # np.float alias removed in modern NumPy
        prob_type = np.float32

        block_xsize, block_ysize = hcr_bands[0].GetBlockSize()
        hcr_type_size = np.max([
            np.dtype(gdal_array.GDALTypeCodeToNumericTypeCode(
                bd.DataType)).itemsize for bd in hcr_bands
        ])
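        # Grow the processing strip so that one read stays within GDAL's
        # raster block cache (GetCacheMax() returns bytes).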
        tmpn = int(gdal.GetCacheMax() /
                   (block_xsize * block_ysize * hcr_type_size))
        if tmpn > int(hcr_raster_ysize / block_ysize):
            tmpn = int(hcr_raster_ysize / block_ysize)
        if tmpn > 1:
            block_ysize = tmpn * block_ysize

        nblocks_x, nblocks_y = np.ceil(
            hcr_raster_xsize / block_xsize).astype(int), np.ceil(
                hcr_raster_ysize / block_ysize).astype(int)

        block_meta_data = np.zeros(4, dtype=int)  # np.int alias removed in modern NumPy
        logger.info(
            "Process {0:d}: Start estimating and writing pixel topics (harmonized class) ..."
            .format(rank))
        logger.info(
            "Process {4:d}: n_blocks_x = {0:d}, n_blocks_y = {1:d}, block_xsize = {2:d}, block_ysize = {3:d}"
            .format(nblocks_x, nblocks_y, block_xsize, block_ysize, rank))

        # Use N-dimensional array to speed up LUT search
        hcr_lut_class_arr = np.zeros(
            [len(ilevel) for ilevel in hcr_lut.index.levels], dtype=out_type)
        hcr_lut_prob_arr = np.zeros(
            [len(ilevel)
             for ilevel in hcr_lut.index.levels] + [len(prob_topic_colnames)],
            dtype=prob_type)
        # Put the LUT into the array (MultiIndex.labels is called
        # MultiIndex.codes in modern pandas)
        hcr_lut_class_arr[tuple(
            hcr_lut.index.codes)] = hcr_lut.loc[:, "primary_topic_id"]
        for i, coln in enumerate(prob_topic_colnames):
            hcr_lut_prob_arr[tuple(
                hcr_lut.index.codes +
                [[i] * len(hcr_lut.index)])] = hcr_lut.loc[:, coln]

        class_driver = gdal.GetDriverByName(class_format)

        class_ds = class_driver.Create(
            class_raster, hcr_raster_xsize, hcr_raster_ysize, 1,
            gdal_array.NumericTypeCodeToGDALTypeCode(out_type))
        class_bd = class_ds.GetRasterBand(1)

        if prob_format is not None:
            prob_driver = gdal.GetDriverByName(prob_format)
        else:
            prob_driver = class_driver

        if prob_raster is not None:
            prob_ds = prob_driver.Create(
                prob_raster, hcr_raster_xsize, hcr_raster_ysize,
                hcr.lda.n_components,
                gdal_array.NumericTypeCodeToGDALTypeCode(prob_type))
        else:
            prob_ds = None

        progress_cnt = 0
        progress_tot = nblocks_x * nblocks_y
        progress_pct = 10
        progress_frc = int(progress_pct / 100. * progress_tot)
        if progress_frc == 0:
            progress_frc = 1
        progress_npct = 0
        for iby in range(nblocks_y):
            for ibx in range(nblocks_x):
                xoff, yoff = ibx * block_xsize, iby * block_ysize
                # win_xsize, win_ysize = hcr_bands[0].GetActualBlockSize(ibx, iby)
                win_xsize = block_xsize if ibx < nblocks_x - 1 else hcr_raster_xsize - xoff
                win_ysize = block_ysize if iby < nblocks_y - 1 else hcr_raster_ysize - yoff

                block_meta_data[:] = [xoff, yoff, win_xsize, win_ysize]
                # Read this block from every input raster band.
                mb_img = [
                    bd.ReadAsArray(xoff, yoff, win_xsize, win_ysize)
                    for bd in hcr_bands
                ]
                # mask out invalid pixels
                img_mask = np.zeros_like(mb_img[0], dtype=bool)
                for ib, img in enumerate(mb_img):
                    tmp_mask = np.ones_like(img, dtype=bool)
                    for v in hcr.class_code2vocab[ib].index:
                        tmp_mask = np.logical_and(tmp_mask, img != v)
                    img_mask = np.logical_or(img_mask, tmp_mask)
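                # So far img_mask flags invalid pixels; invert it so that
                # True marks pixels with a recognized class code in every band.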
                img_mask = np.logical_not(img_mask)

                to_do_flag = np.sum(img_mask) > 0
                class_img = np.zeros_like(img_mask,
                                          dtype=out_type) + class_nodata
                if to_do_flag:
                    # Convert image of vocabulary class labels to image of indexes to these classes in each input raster
                    idx_img_list = []
                    for i, img in enumerate(mb_img):
                        tmp_idx = hcr._translateArray(img[img_mask],
                                                      hcr.class_code2vocab[i])
                        for idx, ilevel in enumerate(hcr_lut.index.levels[i]):
                            tmp_idx[tmp_idx == ilevel] = idx
                        idx_img_list.append(tmp_idx)
                    class_img[img_mask] = hcr_lut_class_arr[tuple(
                        idx_img_list)]

                if prob_raster is not None:
                    prob_img_list = [
                        np.zeros_like(img_mask, dtype=prob_type) + prob_nodata
                        for coln in prob_topic_colnames
                    ]
                    if to_do_flag:
                        for i, coln in enumerate(prob_topic_colnames):
                            prob_img_list[i][img_mask] = hcr_lut_prob_arr[
                                tuple(idx_img_list +
                                      [[i] * len(idx_img_list[0])])]
                    prob_img = np.dstack(prob_img_list)

                class_bd.WriteArray(class_img, int(block_meta_data[0]),
                                    int(block_meta_data[1]))
                if prob_raster is not None:
                    for i in range(prob_img.shape[2]):
                        prob_ds.GetRasterBand(i + 1).WriteArray(
                            prob_img[:, :, i], int(block_meta_data[0]),
                            int(block_meta_data[1]))

                progress_cnt += 1
                if progress_cnt % progress_frc == 0:
                    if progress_frc == 1:
                        progress_npct = int(100 * progress_cnt / progress_tot)
                    else:
                        progress_npct += progress_pct
                    if progress_npct <= 100:
                        logger.info(
                            "Process {1:d}: Finish pixel inference and writing {0:d}%"
                            .format(progress_npct, rank))

        class_ds.FlushCache()
        class_ds.SetGeoTransform(hcr_datasets[0].GetGeoTransform())
        class_ds.SetProjection(hcr_datasets[0].GetProjectionRef())
        # prob_ds is None when no probability raster was requested.
        if prob_ds is not None:
            prob_ds.FlushCache()
            prob_ds.SetGeoTransform(hcr_datasets[0].GetGeoTransform())
            prob_ds.SetProjection(hcr_datasets[0].GetProjectionRef())

        class_ds = None
        prob_ds = None

    # On both master and slave processes, close the raster files that GDAL has
    # opened.
    for i in range(len(hcr_datasets)):
        hcr_datasets[i] = None
    if cmdargs.test is not None:
        hcr_test_ds = None
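
The N-dimensional LUT lookup above is the trick worth keeping: the MultiIndex look-up table is materialized into a dense array so that classifying a whole block becomes a single fancy-indexing call instead of one DataFrame lookup per pixel. A minimal sketch of the same idea, with hypothetical names (lut, a_idx, b_idx) and made-up class counts:

import numpy as np

# Hypothetical LUT for two input rasters: 3 classes in A, 4 classes in B.
lut = np.arange(12).reshape(3, 4)   # lut[a, b] -> primary topic id

a_idx = np.array([0, 2, 1])         # per-pixel class index into raster A
b_idx = np.array([3, 0, 2])         # per-pixel class index into raster B

# One vectorized lookup classifies the whole block, as hcr_lut_class_arr does above.
topics = lut[a_idx, b_idx]
print(topics)                       # [3 8 6]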
Code example #16
def mask_14():

    src_ds = gdal.Open('data/byte.tif')

    if src_ds is None:
        gdaltest.post_reason('Failed to open test dataset.')
        return 'fail'

    drv = gdal.GetDriverByName('GTiff')
    gdal.SetConfigOption('GDAL_TIFF_INTERNAL_MASK_TO_8BIT', 'FALSE')
    ds = drv.CreateCopy('tmp/byte_with_mask.tif', src_ds)
    gdal.SetConfigOption('GDAL_TIFF_INTERNAL_MASK_TO_8BIT', 'TRUE')
    src_ds = None

    # The only flag value supported for internal mask is GMF_PER_DATASET
    with gdaltest.error_handler():
        with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK', 'YES'):
            ret = ds.CreateMaskBand(0)
    if ret == 0:
        gdaltest.post_reason('Error expected')
        return 'fail'

    with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK', 'YES'):
        ret = ds.CreateMaskBand(gdal.GMF_PER_DATASET)
    if ret != 0:
        gdaltest.post_reason('Creation failed')
        return 'fail'

    cs = ds.GetRasterBand(1).GetMaskBand().Checksum()
    if cs != 0:
        print(cs)
        gdaltest.post_reason('Got wrong checksum for the mask (1)')
        return 'fail'

    ds.GetRasterBand(1).GetMaskBand().Fill(1)

    cs = ds.GetRasterBand(1).GetMaskBand().Checksum()
    if cs != 400:
        print(cs)
        gdaltest.post_reason('Got wrong checksum for the mask (2)')
        return 'fail'

    # This TIFF dataset has already an internal mask band
    with gdaltest.error_handler():
        with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK', 'YES'):
            ret = ds.CreateMaskBand(gdal.GMF_PER_DATASET)
    if ret == 0:
        gdaltest.post_reason('Error expected')
        return 'fail'

    # This TIFF dataset has already an internal mask band
    with gdaltest.error_handler():
        with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK', 'YES'):
            ret = ds.GetRasterBand(1).CreateMaskBand(gdal.GMF_PER_DATASET)
    if ret == 0:
        gdaltest.post_reason('Error expected')
        return 'fail'

    ds = None

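    # With an internal mask there must be no external .msk sidecar file.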
    try:
        os.stat('tmp/byte_with_mask.tif.msk')
        gdaltest.post_reason('tmp/byte_with_mask.tif.msk should not exist')
        return 'fail'
    except OSError:
        pass

    gdal.SetConfigOption('GDAL_TIFF_INTERNAL_MASK_TO_8BIT', 'FALSE')
    ds = gdal.Open('tmp/byte_with_mask.tif')

    if ds.GetRasterBand(1).GetMaskFlags() != gdal.GMF_PER_DATASET:
        gdaltest.post_reason('wrong mask flags')
        return 'fail'

    gdal.SetConfigOption('GDAL_TIFF_INTERNAL_MASK_TO_8BIT', 'TRUE')

    cs = ds.GetRasterBand(1).GetMaskBand().Checksum()
    if cs != 400:
        print(cs)
        gdaltest.post_reason('Got wrong checksum for the mask (3)')
        return 'fail'

    # Test fix for #5884
    gdal.SetConfigOption('GDAL_TIFF_INTERNAL_MASK', 'YES')
    old_val = gdal.GetCacheMax()
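    # A zero-byte block cache forces every dirty block to be flushed
    # immediately, the condition that used to trigger #5884.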
    gdal.SetCacheMax(0)
    out_ds = drv.CreateCopy('/vsimem/byte_with_mask.tif',
                            ds,
                            options=['COMPRESS=JPEG'])
    gdal.SetConfigOption('GDAL_TIFF_INTERNAL_MASK', None)
    gdal.SetCacheMax(old_val)
    if out_ds.GetRasterBand(1).Checksum() == 0:
        gdaltest.post_reason('failure')
        return 'fail'
    cs = ds.GetRasterBand(1).GetMaskBand().Checksum()
    if cs != 400:
        print(cs)
        gdaltest.post_reason('Got wrong checksum for the mask (4)')
        return 'fail'
    out_ds = None
    drv.Delete('/vsimem/byte_with_mask.tif')

    ds = None

    drv.Delete('tmp/byte_with_mask.tif')

    return 'success'
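
The test above mixes paired gdal.SetConfigOption() calls with the newer gdaltest.config_option context manager. The context-manager form cannot leak configuration state when an early 'fail' return fires; a minimal sketch, assuming the same gdaltest helper:

from osgeo import gdal
import gdaltest

# The option applies only inside the block and is restored on exit,
# even if the body raises or returns early.
with gdaltest.config_option('GDAL_TIFF_INTERNAL_MASK_TO_8BIT', 'FALSE'):
    ds = gdal.Open('tmp/byte_with_mask.tif')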
Code example #17
	def __init__(self, arguments):
		"""Constructor function - initialization"""
		
		self.stopped = False
		self.input = None
		self.output = None

		# Tile format
		self.tilesize = 256
		self.tiledriver = 'PNG'
		self.tileext = 'png'
		
		# Should we read a bigger window of the input raster and scale it down?
		# Note: Modified later by open_input()
		# Not for 'near' resampling
		# Not for wavelet-based drivers (JPEG2000, ECW, MrSID)
		# Not for 'raster' profile
		self.scaledquery = True
		# How big the query window should be when scaling down
		# Reset later according to the chosen resampling algorithm
		self.querysize = 4 * self.tilesize

		# Should we use Read on the input file for generating overview tiles?
		# Note: Modified later by open_input()
		# Otherwise the overview tiles are generated from existing underlying tiles
		self.overviewquery = False
		
		# RUN THE ARGUMENT PARSER:
		
		self.optparse_init()
		self.options, self.args = self.parser.parse_args(args=arguments)
		if not self.args:
			self.error("No input file specified")

		# POSTPROCESSING OF PARSED ARGUMENTS:

		# Workaround for old versions of GDAL
		try:
			if (self.options.verbose and self.options.resampling == 'near') or gdal.TermProgress_nocb:
				pass
		except AttributeError:
			self.error("This version of GDAL is not supported. Please upgrade to 1.6+.")
			#,"You can try run crippled version of gdal2tiles with parameters: -v -r 'near'")
		
		# Is output directory the last argument?

		# Overwrite output, default to 'input.sqlitedb'
		self.output = self.args[-1] + '.sqlitedb'
		self.store = SqliteTileStorage('TMS')
		self.store.create(self.output, True)

		# Multiple input files are not directly supported yet
		
		if (len(self.args) > 1):
			self.error("Processing of several input files is not supported.",
			"""Please first use a tool like gdal_vrtmerge.py or gdal_merge.py on the files:
gdal_vrtmerge.py -o merged.vrt %s""" % " ".join(self.args))
			# TODO: Call functions from gdal_vrtmerge.py directly
			
		self.input = self.args[0]
		
		# Supported options

		self.resampling = 'antialias'
		
		# User specified zoom levels
		self.tminz = None
		self.tmaxz = None
		if self.options.zoom:
			minmax = self.options.zoom.split('-', 1)
			minmax.extend([''])
			zoom_min, zoom_max = minmax[:2]
			self.tminz = int(zoom_min)
			if zoom_max:
				self.tmaxz = int(zoom_max)
			else:
				self.tmaxz = int(zoom_min)
		

		# Output the results

		if self.options.verbose:
			print("Options:", self.options)
			print("Input:", self.input)
			print("Output:", self.output)
			print("Cache: %s MB" % (gdal.GetCacheMax() / 1024 / 1024))
			print('')