Esempio n. 1
0
def setTable(imgFile, colorTable, bandNumber=1):
    """
    Given either an open gdal dataset, or a filename,
    sets the color table as an array.

    The colorTable is given as an array of shape (4, numEntries)
    where numEntries is the size of the color table. The order of indices
    in the first axis is:

        * Red
        * Green
        * Blue
        * Opacity

    The Red/Green/Blue values are on the range 0-255, with 255 meaning full
    color, and the opacity is in the range 0-255, with 255 meaning fully
    opaque.

    This table is usually generated by getTable() or genTable().

    The colours are stored in the band's raster attribute table (RAT) in the
    columns whose usage is GFU_Red/GFU_Green/GFU_Blue/GFU_Alpha; missing
    columns are created as integer columns named Red/Green/Blue/Alpha.

    Parameters
    ----------
    imgFile : str or gdal.Dataset
        A filename (opened with gdal.GA_Update) or an already-open dataset.
    colorTable : array of shape (4, numEntries)
        Rows in the order Red, Green, Blue, Opacity.
    bandNumber : int, optional
        Band whose RAT is updated (default 1).

    Raises
    ------
    TypeError
        If imgFile is neither a string nor a gdal.Dataset.
    ValueError
        If imgFile is a filename that cannot be opened for update.
    """
    try:
        stringTypes = basestring  # Python 2
    except NameError:
        stringTypes = str  # Python 3

    if isinstance(imgFile, stringTypes):
        ds = gdal.Open(str(imgFile), gdal.GA_Update)
        if ds is None:
            raise ValueError('Unable to open %s for update' % imgFile)
    elif isinstance(imgFile, gdal.Dataset):
        ds = imgFile
    else:
        raise TypeError('imgFile must be a filename or a gdal.Dataset')

    gdalBand = ds.GetRasterBand(bandNumber)
    attrTbl = gdalBand.GetDefaultRAT()
    if attrTbl is None:
        # some formats eg ENVI return None
        # here so we need to be able to cope
        attrTbl = gdal.RasterAttributeTable()
        isFileRAT = False
    else:
        # If the RAT does not support dynamic writing we still have to
        # call SetDefaultRAT at the end for the changes to be stored.
        isFileRAT = attrTbl.ChangesAreWrittenToFile()

    _, numEntries = colorTable.shape
    attrTbl.SetRowCount(numEntries)

    # Set the columns based on their usage, creating them if necessary.
    # An ordered sequence (not a dict) guarantees that row idx of
    # colorTable is paired with the right usage on every Python version.
    colorUsages = (
        (gdal.GFU_Red, 'Red'),
        (gdal.GFU_Green, 'Green'),
        (gdal.GFU_Blue, 'Blue'),
        (gdal.GFU_Alpha, 'Alpha'),
    )
    for idx, (usage, name) in enumerate(colorUsages):
        colNum = attrTbl.GetColOfUsage(usage)
        if colNum == -1:
            attrTbl.CreateColumn(name, gdal.GFT_Integer, usage)
            colNum = attrTbl.GetColumnCount() - 1

        attrTbl.WriteArray(colorTable[idx], colNum)

    if not isFileRAT:
        # The table is detached from the file: attach it to the band.
        gdalBand.SetDefaultRAT(attrTbl)
Esempio n. 2
0
def test_hfa_unique_values_hist():
    """Check the histogram metadata and BinValues RAT column of an HFA
    file (data/i8u_c_i.img) whose histogram uses unique-value bins."""

    try:
        gdal.RasterAttributeTable()
    except Exception:
        # RAT support is not available in these bindings.
        pytest.skip()

    ds = gdal.Open('data/i8u_c_i.img')

    md = ds.GetRasterBand(1).GetMetadata()

    expected = '12603|1|0|0|45|1|0|0|0|0|656|177|0|0|5026|1062|0|0|2|0|0|0|0|0|0|0|0|0|0|0|0|0|75|1|0|0|207|158|0|0|8|34|0|0|0|0|538|57|0|10|214|20|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|1|31|0|0|9|625|67|0|0|118|738|117|3004|1499|491|187|1272|513|1|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|16|3|0|0|283|123|5|1931|835|357|332|944|451|80|40|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|12|5|0|0|535|1029|118|0|33|246|342|0|0|10|8|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|169|439|0|0|6|990|329|0|0|120|295|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|164|42|0|0|570|966|0|0|18|152|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|45|106|0|0|16|16517|'
    assert md[
        'STATISTICS_HISTOBINVALUES'] == expected, 'Unexpected HISTOBINVALUES.'

    assert md['STATISTICS_HISTOMIN'] == '0' and md['STATISTICS_HISTOMAX'] == '255', \
        "unexpected histomin/histomax value."

    # lets also check the RAT to ensure it has the BinValues column added.

    rat = ds.GetRasterBand(1).GetDefaultRAT()
    # Fail with a clear message instead of an AttributeError on None.
    assert rat is not None, 'Missing default RAT.'

    assert rat.GetColumnCount() == 6 and rat.GetTypeOfCol(0) == gdal.GFT_Real and rat.GetUsageOfCol(0) == gdal.GFU_MinMax, \
        'BinValues column wrong.'

    assert rat.GetValueAsInt(2, 0) == 4, 'BinValues value wrong.'

    rat = None

    ds = None
Esempio n. 3
0
File: mem.py Progetto: kongdd/gdal
def test_mem_rat():
    """A MEM band accepts a RAT and lets it be cleared again with None."""
    ds = gdal.GetDriverByName('MEM').Create('', 1, 1)
    band = ds.GetRasterBand(1)

    band.SetDefaultRAT(gdal.RasterAttributeTable())
    assert band.GetDefaultRAT() is not None

    band.SetDefaultRAT(None)
    assert band.GetDefaultRAT() is None
Esempio n. 4
0
def test_ehdr_rat():
    """Round-trip a RAT and a color table through the EHdr driver.

    The fixture data/int16_rat.bil is copied to /vsimem, and its 4x25 RAT
    and color table are checked. Both are then removed in update mode and
    confirmed absent after reopening. Setting an empty RAT on that reopened
    file is expected to fail (non-zero return).
    """
    path = '/vsimem/rat.bil'
    gdal.Translate(path, 'data/int16_rat.bil', format='EHdr')

    ds = gdal.Open(path)
    band = ds.GetRasterBand(1)
    rat = band.GetDefaultRAT()
    assert rat is not None
    assert rat.GetColumnCount() == 4
    assert rat.GetRowCount() == 25
    first_row = (-500, 127, 40, 65)
    for col, expected in enumerate(first_row):
        assert rat.GetValueAsInt(0, col) == expected
    last_row = (2000, 145, 97, 47)
    for col, expected in enumerate(last_row):
        assert rat.GetValueAsInt(24, col) == expected
    assert band.GetColorTable() is not None
    band = None
    ds = None

    # Remove both the RAT and the color table.
    ds = gdal.Open(path, gdal.GA_Update)
    band = ds.GetRasterBand(1)
    band.SetDefaultRAT(None)
    band.SetColorTable(None)
    band = None
    ds = None

    ds = gdal.Open(path, gdal.GA_Update)
    band = ds.GetRasterBand(1)
    assert not band.GetDefaultRAT()
    assert not band.GetColorTable()
    with gdaltest.error_handler():
        status = band.SetDefaultRAT(gdal.RasterAttributeTable())
    assert status != 0
    band = None
    ds = None

    gdal.GetDriverByName('EHDR').Delete(path)
def create_rat(in_raster, lookup, band_number=1):
    """
    Create simple raster attribute table based on lookup {int: string} dict
    Output RAT columns: VALUE (integer), DESCRIPTION (string)
    eg: lookup = {1: "URBAN", 5: "WATER", 11: "AGRICULTURE", 16: "MINING"}
    https://gis.stackexchange.com/questions/333897/read-rat-raster-attribute-table-using-gdal-or-other-python-libraries

    Rows are written in ascending key order. The raster is opened in update
    mode and the RAT is attached to band `band_number` (default 1).

    Raises
    ------
    IOError
        If `in_raster` cannot be opened for update.
    """
    # open the raster at band
    raster = gdal.Open(in_raster, gdal.GA_Update)
    if raster is None:
        raise IOError("Could not open %s for update" % in_raster)
    band = raster.GetRasterBand(band_number)

    # Create and populate the RAT
    rat = gdal.RasterAttributeTable()
    rat.CreateColumn("VALUE", gdal.GFT_Integer, gdal.GFU_Generic)
    rat.CreateColumn("DESCRIPTION", gdal.GFT_String, gdal.GFU_Generic)
    rat.SetRowCount(len(lookup))

    for row, (value, description) in enumerate(sorted(lookup.items())):
        rat.SetValueAsInt(row, 0, int(value))
        rat.SetValueAsString(row, 1, str(description))

    band.SetDefaultRAT(rat)
    # Flush after attaching the RAT (flushing before it has no effect on it).
    raster.FlushCache()
    # Release the band before the dataset that owns it.
    band = None
    raster = None
    rat = None
Esempio n. 6
0
def get_rat_from_vat(filename):
    """Build a gdal.RasterAttributeTable from the first layer of `filename`.

    Every field whose OGR type maps (through the module-level TYPE_MAP) to a
    non-None GDAL field type becomes a RAT column. The column is named after
    the field, and its usage is looked up in USAGE_MAP by field name. Fields
    that map to None are skipped. Each feature, in reading order, becomes
    one RAT row.

    Raises
    ------
    IOError
        If `filename` cannot be opened by OGR.
    """
    md = ogr.Open(filename)
    if md is None:
        raise IOError("Could not open %s" % filename)
    mdl = md.GetLayer(0)
    layer_defn = mdl.GetLayerDefn()

    rat = gdal.RasterAttributeTable()
    # (source field index, RAT column index, RAT field type) of every
    # mapped field. Computing this once keeps column creation and row
    # filling in agreement.
    columns = []
    for field_idx in range(layer_defn.GetFieldCount()):
        field_defn = layer_defn.GetFieldDefn(field_idx)
        field_type = TYPE_MAP[field_defn.GetType()]
        if field_type is None:
            # skip unmappable field type
            continue
        rat.CreateColumn(field_defn.GetName(), field_type,
                         USAGE_MAP[field_defn.GetName()])
        columns.append((field_idx, len(columns), field_type))

    # Iterate features sequentially: FIDs are not guaranteed to be 0..n-1,
    # so GetFeature(index) could return None.
    mdl.ResetReading()
    for row, feature in enumerate(mdl):
        for field_idx, col, field_type in columns:
            if field_type == gdal.GFT_Integer:
                rat.SetValueAsInt(row, col,
                                  feature.GetFieldAsInteger(field_idx))
            elif field_type == gdal.GFT_Real:
                rat.SetValueAsDouble(row, col,
                                     feature.GetFieldAsDouble(field_idx))
            elif field_type == gdal.GFT_String:
                rat.SetValueAsString(row, col,
                                     feature.GetFieldAsString(field_idx))
            # Other mapped types keep the column's default value, without
            # shifting the following columns.
    return rat
Esempio n. 7
0
def test_rat_3():
    """Attaching an empty RAT to a new GTiff band and closing must not fail."""
    path = '/vsimem/rat_3.tif'
    driver = gdal.GetDriverByName('GTiff')

    ds = driver.Create(path, 1, 1)
    ds.GetRasterBand(1).SetDefaultRAT(gdal.RasterAttributeTable())
    ds = None

    driver.Delete(path)
Esempio n. 8
0
File: vrtmisc.py Progetto: ahhz/gdal
def vrtmisc_rat():
    """Old-style gdaltest: check that a band RAT survives VRT serialization.

    A MEM copy of data/byte.tif gets a one-column RAT. The test checks that
    both VRT CreateCopy and gdal.Translate(format='VRT') emit a
    <GDALRasterAttributeTable> element without a GDAL error. It then checks
    that the RAT can be read back from, and removed through, the VRT.

    Returns 'success' or 'fail' (after gdaltest.post_reason).
    """

    ds = gdal.Translate('/vsimem/vrtmisc_rat.tif',
                        'data/byte.tif',
                        format='MEM')
    rat = gdal.RasterAttributeTable()
    rat.CreateColumn("Ints", gdal.GFT_Integer, gdal.GFU_Generic)
    ds.GetRasterBand(1).SetDefaultRAT(rat)

    # Clear errors left by earlier code so the GetLastErrorMsg() check
    # below only reflects errors raised by this serialization.
    gdal.ErrorReset()
    vrt_ds = gdal.GetDriverByName('VRT').CreateCopy('/vsimem/vrtmisc_rat.vrt',
                                                    ds)

    xml_vrt = vrt_ds.GetMetadata('xml:VRT')[0]
    if gdal.GetLastErrorMsg() != '':
        gdaltest.post_reason('fail')
        return 'fail'
    vrt_ds = None

    if xml_vrt.find('<GDALRasterAttributeTable>') < 0:
        gdaltest.post_reason('fail')
        print(xml_vrt)
        return 'fail'

    gdal.ErrorReset()
    vrt_ds = gdal.Translate('/vsimem/vrtmisc_rat.vrt',
                            ds,
                            format='VRT',
                            srcWin=[0, 0, 1, 1])

    xml_vrt = vrt_ds.GetMetadata('xml:VRT')[0]
    if gdal.GetLastErrorMsg() != '':
        gdaltest.post_reason('fail')
        return 'fail'
    vrt_ds = None

    if xml_vrt.find('<GDALRasterAttributeTable>') < 0:
        gdaltest.post_reason('fail')
        print(xml_vrt)
        return 'fail'

    ds = None

    vrt_ds = gdal.Open('/vsimem/vrtmisc_rat.vrt', gdal.GA_Update)
    rat = vrt_ds.GetRasterBand(1).GetDefaultRAT()
    if rat is None or rat.GetColumnCount() != 1:
        gdaltest.post_reason('fail')
        return 'fail'
    vrt_ds.GetRasterBand(1).SetDefaultRAT(None)
    if vrt_ds.GetRasterBand(1).GetDefaultRAT() is not None:
        gdaltest.post_reason('fail')
        return 'fail'
    vrt_ds = None

    ds = None

    gdal.Unlink('/vsimem/vrtmisc_rat.vrt')
    gdal.Unlink('/vsimem/vrtmisc_rat.tif')

    return "success"
Esempio n. 9
0
def rat_1():
    """Old-style gdaltest: build a 3-row RAT in memory and verify its clone.

    The original table (not the clone) is stored in gdaltest.saved_rat for
    later tests. Returns 'skip' if RATs are unavailable, otherwise 'success'
    or 'fail' (after gdaltest.post_reason).
    """

    gdaltest.saved_rat = None

    try:
        rat = gdal.RasterAttributeTable()
    except Exception:
        # RAT support is not available in these bindings.
        return 'skip'

    rat.CreateColumn('Value', gdal.GFT_Integer, gdal.GFU_MinMax)
    rat.CreateColumn('Count', gdal.GFT_Integer, gdal.GFU_PixelCount)

    rat.SetRowCount(3)
    rat.SetValueAsInt(0, 0, 10)
    rat.SetValueAsInt(0, 1, 100)
    rat.SetValueAsInt(1, 0, 11)
    rat.SetValueAsInt(1, 1, 200)
    rat.SetValueAsInt(2, 0, 12)
    rat.SetValueAsInt(2, 1, 90)

    rat2 = rat.Clone()

    if rat2.GetColumnCount() != 2:
        gdaltest.post_reason('wrong column count')
        return 'fail'

    if rat2.GetRowCount() != 3:
        gdaltest.post_reason('wrong row count')
        return 'fail'

    if rat2.GetNameOfCol(1) != 'Count':
        gdaltest.post_reason('wrong column name')
        return 'fail'

    if rat2.GetUsageOfCol(1) != gdal.GFU_PixelCount:
        gdaltest.post_reason('wrong column usage')
        return 'fail'

    if rat2.GetTypeOfCol(1) != gdal.GFT_Integer:
        gdaltest.post_reason('wrong column type')
        return 'fail'

    if rat2.GetRowOfValue(11.0) != 1:
        gdaltest.post_reason('wrong row for value')
        return 'fail'

    if rat2.GetValueAsInt(1, 1) != 200:
        gdaltest.post_reason('wrong field value.')
        return 'fail'

    gdaltest.saved_rat = rat

    return 'success'
Esempio n. 10
0
def ehdr_rat():
    """Old-style gdaltest: round-trip a RAT and color table through EHdr.

    The fixture's 4x25 RAT and its color table are checked, removed in update
    mode, and confirmed absent after reopening. Setting an empty RAT on that
    reopened file is expected to fail. Returns 'success' or 'fail' (after
    gdaltest.post_reason).
    """
    tmpfile = '/vsimem/rat.bil'
    gdal.Translate(tmpfile, 'data/int16_rat.bil', format='EHdr')

    ds = gdal.Open(tmpfile)
    band = ds.GetRasterBand(1)
    rat = band.GetDefaultRAT()
    if rat is None:
        gdaltest.post_reason('fail')
        return 'fail'

    col_count = rat.GetColumnCount()
    if col_count != 4:
        gdaltest.post_reason('fail')
        print(col_count)
        return 'fail'
    row_count = rat.GetRowCount()
    if row_count != 25:
        gdaltest.post_reason('fail')
        print(row_count)
        return 'fail'

    # (row index, expected integer value of each column in that row)
    checks = ((0, (-500, 127, 40, 65)), (24, (2000, 145, 97, 47)))
    for row, expected_values in checks:
        for idx, val in enumerate(expected_values):
            got = rat.GetValueAsInt(row, idx)
            if got != val:
                gdaltest.post_reason('fail')
                print(idx, got)
                return 'fail'

    if band.GetColorTable() is None:
        gdaltest.post_reason('fail')
        return 'fail'
    band = None
    ds = None

    ds = gdal.Open(tmpfile, gdal.GA_Update)
    band = ds.GetRasterBand(1)
    band.SetDefaultRAT(None)
    band.SetColorTable(None)
    band = None
    ds = None

    ds = gdal.Open(tmpfile, gdal.GA_Update)
    band = ds.GetRasterBand(1)
    if band.GetDefaultRAT() or band.GetColorTable():
        gdaltest.post_reason('fail')
        return 'fail'
    with gdaltest.error_handler():
        status = band.SetDefaultRAT(gdal.RasterAttributeTable())
    if status == 0:
        gdaltest.post_reason('fail')
        return 'fail'
    band = None
    ds = None

    gdal.GetDriverByName('EHDR').Delete(tmpfile)

    return 'success'
Esempio n. 11
0
def mem_rat():
    """Old-style gdaltest: a MEM band's RAT can be set and then cleared.

    Returns 'success' or 'fail' (after gdaltest.post_reason).
    """
    ds = gdal.GetDriverByName('MEM').Create('', 1, 1)
    band = ds.GetRasterBand(1)

    band.SetDefaultRAT(gdal.RasterAttributeTable())
    if band.GetDefaultRAT() is None:
        gdaltest.post_reason('fail')
        return 'fail'

    band.SetDefaultRAT(None)
    if band.GetDefaultRAT() is None:
        return 'success'

    gdaltest.post_reason('fail')
    return 'fail'
Esempio n. 12
0
def test_rat_4():
    """Write a RAT to a GTiff, replace it in update mode, and verify both."""
    path = '/vsimem/rat_4.tif'

    def make_rat(value):
        # One-row table: VALUE (MinMax) = value, CLASS (Name) = 'Class1'.
        table = gdal.RasterAttributeTable()
        table.CreateColumn('VALUE', gdal.GFT_Integer, gdal.GFU_MinMax)
        table.CreateColumn('CLASS', gdal.GFT_String, gdal.GFU_Name)
        table.SetValueAsInt(0, 0, value)
        table.SetValueAsString(0, 1, 'Class1')
        return table

    # Initial RAT on a freshly created file.
    ds = gdal.GetDriverByName('GTiff').Create(path, 1, 1)
    rat = make_rat(111)
    ds.GetRasterBand(1).SetDefaultRAT(rat)
    ds = None

    ds = gdal.OpenEx(path)
    rat = ds.GetRasterBand(1).GetDefaultRAT()
    assert rat.GetValueAsInt(0, 0) == 111
    ds = None

    # Replace the existing RAT in update mode.
    rat = make_rat(222)
    ds = gdal.OpenEx(path, gdal.OF_RASTER | gdal.OF_UPDATE)
    ds.GetRasterBand(1).SetDefaultRAT(rat)
    ds = None

    ds = gdal.OpenEx(path)
    rat = ds.GetRasterBand(1).GetDefaultRAT()
    assert rat is not None
    assert rat.GetValueAsInt(0, 0) == 222
    ds = None

    gdal.GetDriverByName('GTiff').Delete(path)
Esempio n. 13
0
def df_to_gdal_rat(df):
    """Convert a pandas DataFrame into a gdal.RasterAttributeTable.

    The caller's frame is copied, not modified. If it has no 'ClassNumber'
    column, one is appended from the index. Every column becomes a RAT
    column of the same position. Its type comes from dtype_map(), and its
    usage comes from f_use_d for names listed in f_names, else GFU_Generic.
    """
    frame = df.copy()
    if 'ClassNumber' not in frame.columns:
        frame['ClassNumber'] = frame.index

    table = gdal.RasterAttributeTable()
    table.SetRowCount(len(frame))
    for position, name in enumerate(frame.columns):
        field_type = dtype_map(frame[name].dtype)
        usage = f_use_d[name] if name in f_names else gdal.GFU_Generic
        # CreateColumn needs a native str (it rejects unicode on Python 2).
        table.CreateColumn(str(name), field_type, usage)
        table.WriteArray(frame[name].tolist(), position)
    return table
Esempio n. 14
0
File: hfa.py Progetto: garnertb/gdal
def hfa_unique_values_hist():
    """Old-style gdaltest: check the histogram metadata and BinValues RAT
    column of an HFA file (data/i8u_c_i.img) with unique-value bins.

    Returns 'skip' if RATs are unavailable, otherwise 'success' or 'fail'
    (after gdaltest.post_reason).
    """

    try:
        gdal.RasterAttributeTable()
    except Exception:
        # RAT support is not available in these bindings.
        return 'skip'

    ds = gdal.Open('data/i8u_c_i.img')

    md = ds.GetRasterBand(1).GetMetadata()

    expected = '12603|1|0|0|45|1|0|0|0|0|656|177|0|0|5026|1062|0|0|2|0|0|0|0|0|0|0|0|0|0|0|0|0|75|1|0|0|207|158|0|0|8|34|0|0|0|0|538|57|0|10|214|20|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|1|31|0|0|9|625|67|0|0|118|738|117|3004|1499|491|187|1272|513|1|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|16|3|0|0|283|123|5|1931|835|357|332|944|451|80|40|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|12|5|0|0|535|1029|118|0|33|246|342|0|0|10|8|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|169|439|0|0|6|990|329|0|0|120|295|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|164|42|0|0|570|966|0|0|18|152|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|0|45|106|0|0|16|16517|'
    if md['STATISTICS_HISTOBINVALUES'] != expected:
        print(md['STATISTICS_HISTOBINVALUES'])
        gdaltest.post_reason('Unexpected HISTOBINVALUES.')
        return 'fail'

    if md['STATISTICS_HISTOMIN'] != '0' \
       or md['STATISTICS_HISTOMAX'] != '255':
        print(md)
        gdaltest.post_reason("unexpected histomin/histomax value.")
        return 'fail'

    # lets also check the RAT to ensure it has the BinValues column added.

    rat = ds.GetRasterBand(1).GetDefaultRAT()

    # Report a missing RAT as a test failure instead of an AttributeError.
    if rat is None:
        gdaltest.post_reason('Missing default RAT.')
        return 'fail'

    if rat.GetColumnCount() != 6 \
       or rat.GetTypeOfCol(0) != gdal.GFT_Real \
       or rat.GetUsageOfCol(0) != gdal.GFU_MinMax:
        print(rat.GetColumnCount())
        print(rat.GetTypeOfCol(0))
        print(rat.GetUsageOfCol(0))
        gdaltest.post_reason('BinValues column wrong.')
        return 'fail'

    if rat.GetValueAsInt(2, 0) != 4:
        print(rat.GetValueAsInt(2, 0))
        gdaltest.post_reason('BinValues value wrong.')
        return 'fail'

    rat = None

    ds = None

    return 'success'
Esempio n. 15
0
def test_vrtmisc_rat():
    """Check that a band RAT survives VRT serialization (CreateCopy and
    gdal.Translate to VRT) without a GDAL error, and that it can be read
    back from, and removed through, the VRT."""

    ds = gdal.Translate('/vsimem/vrtmisc_rat.tif',
                        'data/byte.tif',
                        format='MEM')
    rat = gdal.RasterAttributeTable()
    rat.CreateColumn("Ints", gdal.GFT_Integer, gdal.GFU_Generic)
    ds.GetRasterBand(1).SetDefaultRAT(rat)

    # Clear errors left by earlier tests so the GetLastErrorMsg() assertion
    # below only reflects errors raised by this serialization.
    gdal.ErrorReset()
    vrt_ds = gdal.GetDriverByName('VRT').CreateCopy('/vsimem/vrtmisc_rat.vrt',
                                                    ds)

    xml_vrt = vrt_ds.GetMetadata('xml:VRT')[0]
    assert gdal.GetLastErrorMsg() == ''
    vrt_ds = None

    assert '<GDALRasterAttributeTable tableType="thematic">' in xml_vrt

    gdal.ErrorReset()
    vrt_ds = gdal.Translate('/vsimem/vrtmisc_rat.vrt',
                            ds,
                            format='VRT',
                            srcWin=[0, 0, 1, 1])

    xml_vrt = vrt_ds.GetMetadata('xml:VRT')[0]
    assert gdal.GetLastErrorMsg() == ''
    vrt_ds = None

    assert '<GDALRasterAttributeTable tableType="thematic">' in xml_vrt

    ds = None

    vrt_ds = gdal.Open('/vsimem/vrtmisc_rat.vrt', gdal.GA_Update)
    rat = vrt_ds.GetRasterBand(1).GetDefaultRAT()
    assert rat is not None and rat.GetColumnCount() == 1
    vrt_ds.GetRasterBand(1).SetDefaultRAT(None)
    assert vrt_ds.GetRasterBand(1).GetDefaultRAT() is None
    vrt_ds = None

    ds = None

    gdal.Unlink('/vsimem/vrtmisc_rat.vrt')
    gdal.Unlink('/vsimem/vrtmisc_rat.tif')
Esempio n. 16
0
def main(rin, data):
    """Attach a RAT built from a DBF or CSV file to band 1 of raster `rin`.

    Parameters
    ----------
    rin : str
        Raster that receives the RAT. It is opened with gdal.Open's default
        access, as before.
    data : str
        Attribute source. The extension (case-insensitive) selects
        dbf_process (.dbf) or csv_process (.csv), which fill the RAT.

    Raises
    ------
    ValueError
        If `data` is neither a .dbf nor a .csv file. This is an Exception
        subclass, so existing `except Exception` handlers still catch it.
    IOError
        If `rin` cannot be opened.
    """
    data_ext = os.path.splitext(data)[1].lower()
    # Validate the input before touching GDAL.
    if data_ext not in ('.dbf', '.csv'):
        raise ValueError('Datafile must be DBF or CSV')

    rat = gdal.RasterAttributeTable()
    if data_ext == '.dbf':
        dbf_process(rat, data)
    else:
        csv_process(rat, data)

    ds_in = gdal.Open(rin)
    if ds_in is None:
        raise IOError('Could not open raster %s' % rin)
    ds_in.GetRasterBand(1).SetDefaultRAT(rat)
    # Dropping the last reference closes the dataset and writes the RAT.
    ds_in = None
    print('done')
Esempio n. 17
0
def test_rat_1():
    """Build a 3-row RAT in memory and verify that its clone matches.

    The original table (not the clone) is stored in gdaltest.saved_rat for
    later tests.
    """

    gdaltest.saved_rat = None

    try:
        rat = gdal.RasterAttributeTable()
    except Exception:
        # RAT support is not available in these bindings.
        pytest.skip()

    rat.CreateColumn('Value', gdal.GFT_Integer, gdal.GFU_MinMax)
    rat.CreateColumn('Count', gdal.GFT_Integer, gdal.GFU_PixelCount)

    rat.SetRowCount(3)
    rat.SetValueAsInt(0, 0, 10)
    rat.SetValueAsInt(0, 1, 100)
    rat.SetValueAsInt(1, 0, 11)
    rat.SetValueAsInt(1, 1, 200)
    rat.SetValueAsInt(2, 0, 12)
    rat.SetValueAsInt(2, 1, 90)

    rat2 = rat.Clone()

    assert rat2.GetColumnCount() == 2, 'wrong column count'

    assert rat2.GetRowCount() == 3, 'wrong row count'

    assert rat2.GetNameOfCol(1) == 'Count', 'wrong column name'

    assert rat2.GetUsageOfCol(1) == gdal.GFU_PixelCount, 'wrong column usage'

    assert rat2.GetTypeOfCol(1) == gdal.GFT_Integer, 'wrong column type'

    assert rat2.GetRowOfValue(11.0) == 1, 'wrong row for value'

    assert rat2.GetValueAsInt(1, 1) == 200, 'wrong field value.'

    gdaltest.saved_rat = rat
Esempio n. 18
0
def createRAT(path=r'D:\cag_poultry\test\data\result\test_review5.tif'):
    """Attach a RAT of (unique pixel value, random number) rows to band 1.

    Parameters
    ----------
    path : str, optional
        Raster to annotate. Defaults to the previously hard-coded path, so
        existing no-argument calls behave as before.

    The table has two GFT_Real columns. "Value" holds each distinct value
    returned by numpy.unique on band 1, in ascending order. "RANDOM" holds
    a number drawn from numpy.random.uniform(0, 1000) for each row.

    Raises
    ------
    IOError
        If `path` cannot be opened.
    """
    ds = gdal.Open(path)
    if ds is None:
        raise IOError('Could not open %s' % path)
    rb = ds.GetRasterBand(1)
    values = numpy.unique(rb.ReadAsArray())
    randoms = numpy.random.uniform(0, 1000, size=values.size)

    rat = gdal.RasterAttributeTable()
    rat.CreateColumn("Value", gdal.GFT_Real, gdal.GFU_Generic)
    rat.CreateColumn("RANDOM", gdal.GFT_Real, gdal.GFU_Generic)
    rat.SetRowCount(values.size)
    for row, (value, rnd) in enumerate(zip(values, randoms)):
        rat.SetValueAsDouble(row, 0, float(value))
        rat.SetValueAsDouble(row, 1, float(rnd))
    rb.SetDefaultRAT(rat)

    # Release the band before closing the dataset that owns it.
    rb = None
    ds = None
Esempio n. 19
0
import os

from osgeo import gdal

# Don't forget to change the folder.
# (os was previously used here without being imported, which raised NameError.)
os.chdir(r'D:\osgeopy-data\Switzerland')

# Open the output from listing 9.3 and get the band.
ds = gdal.Open('dem_class2.tif')
band = ds.GetRasterBand(1)

# Change the NoData value to -1 so that the histogram will be computed
# using 0 values.
band.SetNoDataValue(-1)

# Create the raster attribute table and add 3 columns for the pixel value,
# number of pixels with that value, and elevation label.
rat = gdal.RasterAttributeTable()
rat.CreateColumn('Value', gdal.GFT_Integer, gdal.GFU_Name)
rat.CreateColumn('Count', gdal.GFT_Integer, gdal.GFU_PixelCount)
rat.CreateColumn('Elevation', gdal.GFT_String, gdal.GFU_Generic)

# Add 6 rows to the table, for values 0-5.
rat.SetRowCount(6)

# Write the values 0-5 (using range) to the first column (pixel value).
rat.WriteArray(range(6), 0)

# Get the histogram and write the results to the second column (count).
# Six bins centred on the integers 0..5 (edges -0.5 .. 5.5).
rat.WriteArray(band.GetHistogram(-0.5, 5.5, 6, False, False), 1)

# Add the labels for each pixel value to the third column.
# NOTE(review): this excerpt only labels row 1 and never attaches the RAT
# (band.SetDefaultRAT(rat)); presumably the listing continues elsewhere.
rat.SetValueAsString(1, 2, '0 - 800')
Esempio n. 20
0
def _gdal_api_proxy_sub():
    """Exercise the dataset/band API through the API_PROXY driver.

    Reads reference values (checksum, geotransform, projection, raw pixels,
    metadata) from data/byte.tif, and asserts that gdal.IdentifyDriver
    reports 'API_PROXY' for that file. It then writes tmp/byte.tif and runs
    a long sequence of get/set calls against it, asserting what each call
    returns. The calls cover geotransform, GCPs, projection, raster I/O,
    metadata, color table, RAT, statistics, overviews, histograms, nodata,
    masks, unit type and category names. tmp/byte.tif is deleted at the end.

    NOTE(review): presumably run in a context where the API proxy is enabled
    for these paths — confirm against the caller.
    Later assertions depend on state set by earlier ones; do not reorder.
    """

    # Reference values read before any checks.
    src_ds = gdal.Open('data/byte.tif')
    src_cs = src_ds.GetRasterBand(1).Checksum()
    src_gt = src_ds.GetGeoTransform()
    src_prj = src_ds.GetProjectionRef()
    src_data = src_ds.ReadRaster(0, 0, 20, 20)
    src_md = src_ds.GetMetadata()
    src_ds = None

    drv = gdal.IdentifyDriver('data/byte.tif')
    assert drv.GetDescription() == 'API_PROXY'

    # Create a throwaway 3-band file, then overwrite it with a tiled copy.
    ds = gdal.GetDriverByName('GTiff').Create('tmp/byte.tif', 1, 1, 3)
    ds = None

    src_ds = gdal.Open('data/byte.tif')
    ds = gdal.GetDriverByName('GTiff').CreateCopy('tmp/byte.tif',
                                                  src_ds,
                                                  options=['TILED=YES'])
    got_cs = ds.GetRasterBand(1).Checksum()
    assert src_cs == got_cs
    ds = None

    ds = gdal.Open('tmp/byte.tif', gdal.GA_Update)

    # Geotransform: set a different value, then restore the original.
    ds.SetGeoTransform([1, 2, 3, 4, 5, 6])
    got_gt = ds.GetGeoTransform()
    assert src_gt != got_gt

    ds.SetGeoTransform(src_gt)
    got_gt = ds.GetGeoTransform()
    assert src_gt == got_gt

    # GCPs: none initially, set one with an EPSG:4326 WKT, then clear them.
    assert ds.GetGCPCount() == 0

    assert ds.GetGCPProjection() == '', ds.GetGCPProjection()

    assert not ds.GetGCPs()

    gcps = [gdal.GCP(0, 1, 2, 3, 4)]
    sr = osr.SpatialReference()
    sr.ImportFromEPSG(4326)
    wkt = sr.ExportToWkt()
    assert ds.SetGCPs(gcps, wkt) == 0

    got_gcps = ds.GetGCPs()
    assert len(got_gcps) == 1

    assert (got_gcps[0].GCPLine == gcps[0].GCPLine and  \
       got_gcps[0].GCPPixel == gcps[0].GCPPixel and  \
       got_gcps[0].GCPX == gcps[0].GCPX and \
       got_gcps[0].GCPY == gcps[0].GCPY)

    assert ds.GetGCPProjection() == wkt

    ds.SetGCPs([], "")

    assert not ds.GetGCPs()

    # Projection: clear it, then restore the original.
    ds.SetProjection('')
    got_prj = ds.GetProjectionRef()
    assert src_prj != got_prj

    ds.SetProjection(src_prj)
    got_prj = ds.GetProjectionRef()
    assert src_prj == got_prj

    # Raster I/O: zero-fill, then restore the pixels via band and dataset
    # writes, checking the checksum each time.
    ds.GetRasterBand(1).Fill(0)
    got_cs = ds.GetRasterBand(1).Checksum()
    assert got_cs == 0

    ds.GetRasterBand(1).WriteRaster(0, 0, 20, 20, src_data)
    got_cs = ds.GetRasterBand(1).Checksum()
    assert src_cs == got_cs

    ds.GetRasterBand(1).Fill(0)
    got_cs = ds.GetRasterBand(1).Checksum()
    assert got_cs == 0

    ds.WriteRaster(0, 0, 20, 20, src_data)
    got_cs = ds.GetRasterBand(1).Checksum()
    assert src_cs == got_cs

    # Not bound to SWIG
    # ds.AdviseRead(0,0,20,20,20,20)

    got_data = ds.ReadRaster(0, 0, 20, 20)
    assert src_data == got_data

    got_data = ds.GetRasterBand(1).ReadRaster(0, 0, 20, 20)
    assert src_data == got_data

    # A 32-byte line stride: (20 - 1) full lines plus one 20-byte last line.
    got_data_weird_spacing = ds.ReadRaster(0,
                                           0,
                                           20,
                                           20,
                                           buf_pixel_space=1,
                                           buf_line_space=32)
    assert len(got_data_weird_spacing) == 32 * (20 - 1) + 20

    assert got_data[20:20 + 20] == got_data_weird_spacing[32:32 + 20]

    got_data_weird_spacing = ds.GetRasterBand(1).ReadRaster(0,
                                                            0,
                                                            20,
                                                            20,
                                                            buf_pixel_space=1,
                                                            buf_line_space=32)
    assert len(got_data_weird_spacing) == 32 * (20 - 1) + 20

    assert got_data[20:20 + 20] == got_data_weird_spacing[32:32 + 20]

    # Block (0, 0) is 256x256 bytes here, so row 1 starts at offset 256.
    got_block = ds.GetRasterBand(1).ReadBlock(0, 0)
    assert len(got_block) == 256 * 256

    assert got_data[20:20 + 20] == got_block[256:256 + 20]

    ds.FlushCache()
    ds.GetRasterBand(1).FlushCache()

    got_data = ds.GetRasterBand(1).ReadRaster(0, 0, 20, 20)
    assert src_data == got_data

    assert len(ds.GetFileList()) == 1

    # AddBand is expected to fail (non-zero return) on this dataset.
    assert ds.AddBand(gdal.GDT_Byte) != 0

    # Metadata on the dataset and the band, default and 'OTHER' domains.
    got_md = ds.GetMetadata()
    assert src_md == got_md

    assert ds.GetMetadataItem('AREA_OR_POINT') == 'Area'

    assert ds.GetMetadataItem('foo') is None

    ds.SetMetadataItem('foo', 'bar')
    assert ds.GetMetadataItem('foo') == 'bar'

    ds.SetMetadata({'foo': 'baz'}, 'OTHER')
    assert ds.GetMetadataItem('foo', 'OTHER') == 'baz'

    ds.GetRasterBand(1).SetMetadata({'foo': 'baw'}, 'OTHER')
    assert ds.GetRasterBand(1).GetMetadataItem('foo', 'OTHER') == 'baw'

    assert ds.GetMetadataItem('INTERLEAVE', 'IMAGE_STRUCTURE') == 'BAND'

    assert not ds.GetRasterBand(1).GetMetadata()

    assert ds.GetRasterBand(1).GetMetadataItem('foo') is None

    ds.GetRasterBand(1).SetMetadataItem('foo', 'baz')
    assert ds.GetRasterBand(1).GetMetadataItem('foo') == 'baz'

    ds.GetRasterBand(1).SetMetadata({'foo': 'baw'})
    assert ds.GetRasterBand(1).GetMetadataItem('foo') == 'baw'

    assert ds.GetRasterBand(1).GetColorInterpretation() == gdal.GCI_GrayIndex

    ds.GetRasterBand(1).SetColorInterpretation(gdal.GCI_Undefined)

    # Color table: none initially, set one entry (read back with alpha 255),
    # then remove it.
    ct = ds.GetRasterBand(1).GetColorTable()
    assert ct is None

    ct = gdal.ColorTable()
    ct.SetColorEntry(0, (1, 2, 3))
    assert ds.GetRasterBand(1).SetColorTable(ct) == 0

    ct = ds.GetRasterBand(1).GetColorTable()
    assert ct is not None
    assert ct.GetColorEntry(0) == (1, 2, 3, 255)

    ct = ds.GetRasterBand(1).GetColorTable()
    assert ct is not None

    assert ds.GetRasterBand(1).SetColorTable(None) == 0

    ct = ds.GetRasterBand(1).GetColorTable()
    assert ct is None

    # RAT: setting an empty RAT returns 0, yet GetDefaultRAT still reports
    # None afterwards; this test pins that behavior.
    rat = ds.GetRasterBand(1).GetDefaultRAT()
    assert rat is None

    assert ds.GetRasterBand(1).SetDefaultRAT(None) == 0

    ref_rat = gdal.RasterAttributeTable()
    assert ds.GetRasterBand(1).SetDefaultRAT(ref_rat) == 0

    rat = ds.GetRasterBand(1).GetDefaultRAT()
    assert rat is None

    assert ds.GetRasterBand(1).SetDefaultRAT(None) == 0

    rat = ds.GetRasterBand(1).GetDefaultRAT()
    assert rat is None

    # Statistics. Before computation, element [3] is negative (presumably
    # meaning "no statistics available" — confirm).
    assert ds.GetRasterBand(1).GetMinimum() is None

    got_stats = ds.GetRasterBand(1).GetStatistics(0, 0)
    assert got_stats[3] < 0.0

    got_stats = ds.GetRasterBand(1).GetStatistics(1, 1)
    assert got_stats[0] == 74.0

    assert ds.GetRasterBand(1).GetMinimum() == 74.0

    assert ds.GetRasterBand(1).GetMaximum() == 255.0

    ds.GetRasterBand(1).SetStatistics(1, 2, 3, 4)
    got_stats = ds.GetRasterBand(1).GetStatistics(1, 1)
    assert got_stats == [1, 2, 3, 4]

    ds.GetRasterBand(1).ComputeStatistics(0)
    got_stats = ds.GetRasterBand(1).GetStatistics(1, 1)
    assert got_stats[0] == 74.0

    minmax = ds.GetRasterBand(1).ComputeRasterMinMax()
    assert minmax == (74.0, 255.0)

    # Offset / scale defaults, then set and read back.
    assert ds.GetRasterBand(1).GetOffset() == 0.0

    assert ds.GetRasterBand(1).GetScale() == 1.0

    ds.GetRasterBand(1).SetOffset(10.0)
    assert ds.GetRasterBand(1).GetOffset() == 10.0

    ds.GetRasterBand(1).SetScale(2.0)
    assert ds.GetRasterBand(1).GetScale() == 2.0

    # Overviews: build one level; an out-of-range index returns None.
    ds.BuildOverviews('NEAR', [2])
    assert ds.GetRasterBand(1).GetOverviewCount() == 1

    assert ds.GetRasterBand(1).GetOverview(-1) is None

    assert ds.GetRasterBand(1).GetOverview(0) is not None

    assert ds.GetRasterBand(1).GetOverview(0) is not None

    # Histograms: the default histogram matches GetHistogram(); then a
    # custom one is set and read back.
    got_hist = ds.GetRasterBand(1).GetHistogram()
    assert len(got_hist) == 256

    (minval, maxval, nitems,
     got_hist2) = ds.GetRasterBand(1).GetDefaultHistogram()
    assert minval == -0.5
    assert maxval == 255.5
    assert nitems == 256
    assert got_hist == got_hist2

    ds.GetRasterBand(1).SetDefaultHistogram(1, 2, [3])
    (minval, maxval, nitems,
     got_hist3) = ds.GetRasterBand(1).GetDefaultHistogram()
    assert minval == 1
    assert maxval == 2
    assert nitems == 1
    assert got_hist3[0] == 3

    # NoData and mask flags. 8 and 2 are presumably GMF_NODATA and
    # GMF_PER_DATASET respectively — confirm against the GDAL constants.
    got_nodatavalue = ds.GetRasterBand(1).GetNoDataValue()
    assert got_nodatavalue is None

    ds.GetRasterBand(1).SetNoDataValue(123)
    got_nodatavalue = ds.GetRasterBand(1).GetNoDataValue()
    assert got_nodatavalue == 123

    assert ds.GetRasterBand(1).GetMaskFlags() == 8

    assert ds.GetRasterBand(1).GetMaskBand() is not None

    ret = ds.GetRasterBand(1).DeleteNoDataValue()
    assert ret == 0
    got_nodatavalue = ds.GetRasterBand(1).GetNoDataValue()
    assert got_nodatavalue is None

    ds.CreateMaskBand(0)

    assert ds.GetRasterBand(1).GetMaskFlags() == 2

    assert ds.GetRasterBand(1).GetMaskBand() is not None

    ds.GetRasterBand(1).CreateMaskBand(0)

    assert ds.GetRasterBand(1).HasArbitraryOverviews() == 0

    # Miscellaneous band properties.
    ds.GetRasterBand(1).SetUnitType('foo')
    assert ds.GetRasterBand(1).GetUnitType() == 'foo'

    assert ds.GetRasterBand(1).GetCategoryNames() is None

    ds.GetRasterBand(1).SetCategoryNames(['foo'])
    assert ds.GetRasterBand(1).GetCategoryNames() == ['foo']

    ds.GetRasterBand(1).SetDescription('bar')

    ds = None

    gdal.GetDriverByName('GTiff').Delete('tmp/byte.tif')
    def findBurnScars(self,
                      bp_image,
                      seed_prob_thresh=97.5,
                      seed_size_thresh=5,
                      flood_fill_prob_thresh=75,
                      log_handler=None):
        """Grow burn scars from high-probability seeds and label the result.

        Description: routine to find burn scars using the flood-fill approach.
          Seed pixels are pixels whose burn probability is at or above the
          seed probability threshold.  Connected groups of seed pixels whose
          pixel count is at least the seed size threshold are used as seed
          areas for growing the burn extent using the flood fill process.
          Smaller seed areas are ignored and not flagged as burn areas.  The
          union of all flood-filled areas is then relabeled into connected
          burn regions, which are summarized in a raster attribute table.

        History:
          Created in 2013 by Jodi Riegle and Todd Hawbaker, USGS Rocky Mountain
              Geographic Science Center
          Updated on Nov. 26, 2013 by Gail Schmidt, USGS/EROS LSRD Project
              Modified to use int32 arrays vs. int64 arrays for the burn
              classification images
          Updated on Feb. 13, 2015 by Gail Schmidt, USGS/EROS LSRD Project
              Modified the skimage.measure.regionprops call to not use the
              deprecated 'properties' parameter.  Also changed the properties
              values to match the correct names of the dynamic list of props
              which is now created.

        Args:
          bp_image - input image of burn probabilities
          seed_prob_thresh - pixels with a burn probability at or above this
              value are seed pixels for the burn area; default is 97.5%
          seed_size_thresh - minimum number of connected seed pixels needed
              for a seed area to be flood-filled.  Smaller seed areas are
              treated as false positives and not used in the burn
              classification.  Default is 5 pixels.
          flood_fill_prob_thresh - threshold passed to the flood fill as the
              local threshold for adding burn pixels from the burn
              probability image; default is 75%
          log_handler - file handler for the log file; if this is None then
              informational/error messages will be written to stdout

        Returns:
          [bp_regions2, label_rat] where
            bp_regions2 - int32 array shaped like bp_image in which each
                connected burn region has a label 1..n and all other
                pixels are 0
            label_rat - gdal.RasterAttributeTable with one row per burn
                region: column 0 "Value" holds the region label, followed by
                real columns area, filled_area, max_intensity,
                mean_intensity and min_intensity (intensities measured on
                bp_image)

        Notes:
          1. The default lower threshold for flood filling is 75% burn
             probability.
        """

        # array to hold filled region labels; 0 means unburned
        bp_regions = numpy.zeros_like(bp_image, dtype=numpy.int32)

        # seed pixels are those at or above the seed probability threshold
        bp_seeds = bp_image >= seed_prob_thresh

        # group the seed pixels into connected regions; when ``output`` is
        # supplied, scipy.ndimage.label fills it in place and returns only
        # the number of regions found
        bp_seed_regions = numpy.zeros_like(bp_seeds, dtype=numpy.int32)
        n_seed_labels = scipy.ndimage.label(bp_seeds, output=bp_seed_regions)
        msg = 'Found %d seeds to use for flood fill' % n_seed_labels
        logIt(msg, log_handler)

        # flood fill from the first pixel of every sufficiently large seed
        # region, writing that region's label into bp_regions
        seed_region_props = skimage.measure.regionprops(
            label_image=bp_seed_regions)
        for region in seed_region_props:
            if region['area'] < seed_size_thresh:
                continue

            first_coord = region['coords'][0]
            self.floodFill(input_image=bp_image,
                           row=first_coord[0],
                           col=first_coord[1],
                           output_image=bp_regions,
                           output_label=region['label'],
                           local_threshold=flood_fill_prob_thresh,
                           nodata=-9999)

        # relabel the union of all flood-filled areas into connected regions
        burned = bp_regions > 0
        bp_regions2 = numpy.zeros_like(burned, dtype=numpy.int32)
        n_labels = scipy.ndimage.label(burned, output=bp_regions2)
        prop_names = ['area', 'filled_area', 'max_intensity',
                      'mean_intensity', 'min_intensity']
        bp_region2_props = skimage.measure.regionprops(
            label_image=bp_regions2, intensity_image=bp_image)

        # define the RAT (raster attribute table): a "Value" column for the
        # region label followed by one real column per region property
        label_rat = gdal.RasterAttributeTable()
        label_rat.CreateColumn("Value", gdalconst.GFT_Integer,
                               gdalconst.GFU_MinMax)
        for prop in prop_names:
            label_rat.CreateColumn(prop, gdalconst.GFT_Real,
                                   gdalconst.GFU_MinMax)
        label_rat.SetRowCount(n_labels)

        # regionprops yields one entry per label present (1..n_labels),
        # so RAT row i describes region bp_region2_props[i]
        for i in range(n_labels):
            props = bp_region2_props[i]
            label_rat.SetValueAsInt(i, 0, props['label'])
            for col_idx, prop in enumerate(prop_names, 1):
                label_rat.SetValueAsDouble(i, col_idx, float(props[prop]))

        return [bp_regions2, label_rat]
Esempio n. 22
0
    def make_pastures_mask(self, raster_fn, ranch, dst_fn, nodata=-9999):
        """
        Rasterize the pastures of ``ranch`` onto the grid of ``raster_fn``.

        Pasture polygons are read from the shapefile ``self._d['sf_fn']``
        (relative to ``self.loc_path``). Each feature's key property
        (``self._d['sf_feature_properties_key']``, spaces replaced by '_') is
        split on ``self.key_delimiter`` into pasture and ranch names; the
        order is swapped when ``self.reverse_key`` is set. Features whose
        ranch matches ``ranch`` case-insensitively are burned into a uint16
        mask. Each distinct pasture gets an integer id starting at 1, in
        order of first appearance; all other pixels stay 0.

        Outputs, derived from ``dst_fn``:
          * ``<name>.utm.tif`` - the mask on the grid/profile of ``raster_fn``
          * ``<name>.wgs.tif`` - the mask warped to EPSG:4326 via a temporary
            ``<name>.wgs.vrt``, which is removed afterwards
        Both receive a raster attribute table mapping VALUE (id) -> PASTURE.

        :param raster_fn: utm raster whose grid and profile define the mask
        :param ranch: name of the ranch whose pastures are rasterized
        :param dst_fn: '.tif' path in an existing directory, used to derive
            the output file names (an existing file at exactly this path is
            deleted)
        :param nodata: nodata value recorded in the output profile.
            NOTE(review): the default -9999 lies outside the uint16 range;
            recent rasterio versions may reject it when opening for write -
            confirm with callers.
        :return: None
        """

        assert _exists(_split(dst_fn)[0])
        assert dst_fn.endswith('.tif')

        _d = self._d
        sf_fn = os.path.abspath(_join(self.loc_path, _d['sf_fn']))
        sf_feature_properties_key = _d['sf_feature_properties_key']
        reverse_key = self.reverse_key

        pastures = {}  # pasture name -> integer id (1-based)

        # Context managers guarantee the raster and the shapefile are closed
        # even if a feature fails (both were previously left open).
        with rasterio.open(raster_fn) as ds, fiona.open(sf_fn, 'r') as sf:
            ds_proj4 = ds.crs.to_proj4()
            profile = ds.profile
            pastures_mask = np.zeros(ds.shape, dtype=np.uint16)

            for feature in sf:
                key = feature['properties'][sf_feature_properties_key]
                key = key.replace(' ', '_')

                if reverse_key:
                    _ranch, _pasture = key.split(self.key_delimiter)
                else:
                    _pasture, _ranch = key.split(self.key_delimiter)

                if _ranch.lower() != ranch.lower():
                    continue

                pasture_id = pastures.setdefault(_pasture, len(pastures) + 1)

                geometry = transform_geom(sf.crs_wkt, ds_proj4,
                                          feature['geometry'])
                # raster_geometry_mask is True *outside* the geometry, so the
                # pasture id is burned where the mask is False
                _mask, _, _ = raster_geometry_mask(ds, [geometry])
                pastures_mask[np.logical_not(_mask)] = pasture_id

        head, tail = _split(dst_fn)
        utm_dst_fn = _join(head, tail.replace('.tif', '.utm.tif'))
        dst_vrt_fn = _join(head, tail.replace('.tif', '.wgs.vrt'))
        dst_wgs_fn = _join(head, tail.replace('.tif', '.wgs.tif'))

        # 1. write the mask on the source (utm) grid
        dtype = rasterio.uint16
        with rasterio.Env():
            profile.update(count=1,
                           dtype=dtype,
                           nodata=nodata,
                           compress='lzw')
            with rasterio.open(utm_dst_fn, 'w', **profile) as dst:
                dst.write(pastures_mask.astype(dtype), 1)
        assert _exists(utm_dst_fn)

        # 2. warp to EPSG:4326 as a VRT; a stale VRT is removed first so the
        #    existence check reflects this run
        try:
            if _exists(dst_vrt_fn):
                os.remove(dst_vrt_fn)

            p = Popen(['gdalwarp', '-t_srs', 'EPSG:4326', '-of', 'vrt',
                       utm_dst_fn, dst_vrt_fn])
            p.wait()

            assert _exists(dst_vrt_fn)
        except BaseException:
            if _exists(dst_vrt_fn):
                os.remove(dst_vrt_fn)
            raise

        # 3. materialize the VRT as an LZW-compressed GeoTIFF
        try:
            # NOTE(review): dst_fn itself is never written by this method;
            # its removal is kept from the original implementation. The
            # stale *.wgs.tif is removed too, so the check below cannot be
            # satisfied by output left over from an earlier run.
            for stale_fn in (dst_fn, dst_wgs_fn):
                if _exists(stale_fn):
                    os.remove(stale_fn)

            p = Popen(['gdal_translate', '-co', 'COMPRESS=LZW', '-of',
                       'GTiff', dst_vrt_fn, dst_wgs_fn])
            p.wait()

            assert _exists(dst_wgs_fn)
        except BaseException:
            if _exists(dst_wgs_fn):
                os.remove(dst_wgs_fn)
            raise

        if _exists(dst_vrt_fn):
            os.remove(dst_vrt_fn)

        # 4. attach a VALUE -> PASTURE raster attribute table to both outputs
        # https://gdal.org/python/osgeo.gdal.RasterAttributeTable-class.html
        # https://gdal.org/python/osgeo.gdalconst-module.html
        rat = gdal.RasterAttributeTable()
        rat.CreateColumn('VALUE', gdal.GFT_Integer, gdal.GFU_Generic)
        rat.CreateColumn('PASTURE', gdal.GFT_String, gdal.GFU_Generic)
        for row, (pasture, pasture_id) in enumerate(pastures.items()):
            rat.SetValueAsInt(row, 0, pasture_id)
            rat.SetValueAsString(row, 1, pasture)

        for output_raster in (dst_wgs_fn, utm_dst_fn):
            out_ds = gdal.Open(output_raster)
            # SetDefaultRAT stores a copy, so one table serves both files
            out_ds.GetRasterBand(1).SetDefaultRAT(rat)
            # drop the only reference to close the dataset and persist the RAT
            out_ds = None
Esempio n. 23
0
def writeColumnToBand(gdalBand,
                      colName,
                      sequence,
                      colType=None,
                      colUsage=gdal.GFU_Generic):
    """
    Write ``sequence`` to the named column of the band's raster attribute
    table.

    ``sequence`` can be a list, tuple, array etc. If no column called
    ``colName`` exists it is created with ``colType`` and ``colUsage``;
    otherwise the existing column is overwritten.

    colType must be one of gdal.GFT_Integer, gdal.GFT_Real, gdal.GFT_String;
    when None it is inferred from ``sequence``. colUsage can be any of the
    gdal.GFU_* constants (default is generic) and is only used when the
    column is created.

    The GDAL dataset must have been created, or opened with GA_Update.

    Raises rioserrors.AttributeTableTypeError if the column type cannot be
    inferred or is not a valid GDAL column type.
    """

    if colType is None:
        colType = inferColumnType(sequence)
    if colType is None:
        # avoid an IndexError on an empty sequence; report the type error
        if len(sequence) > 0:
            msg = "Can't infer type of column for sequence of %s" % type(
                sequence[0])
        else:
            msg = "Can't infer type of column for an empty sequence"
        raise rioserrors.AttributeTableTypeError(msg)

    # check it is actually a valid type
    elif colType not in (gdal.GFT_Integer, gdal.GFT_Real, gdal.GFT_String):
        msg = "coltype must be a valid gdal column type"
        raise rioserrors.AttributeTableTypeError(msg)

    attrTbl = gdalBand.GetDefaultRAT()
    if attrTbl is None:
        # some formats eg ENVI return None
        # here so we need to be able to cope
        attrTbl = gdal.RasterAttributeTable()
        isFileRAT = False
    else:
        # if the format doesn't write changes through dynamically we still
        # have to call SetDefaultRAT at the end
        isFileRAT = attrTbl.ChangesAreWrittenToFile()

    # reuse an existing column of this name, otherwise create it
    colNum = -1
    for n in range(attrTbl.GetColumnCount()):
        if attrTbl.GetNameOfCol(n) == colName:
            colNum = n
            break
    if colNum == -1:
        attrTbl.CreateColumn(colName, colType, colUsage)
        colNum = attrTbl.GetColumnCount() - 1

    rowsToAdd = len(sequence)
    # Imagine has trouble if not 256 items for byte
    if gdalBand.DataType == gdal.GDT_Byte:
        rowsToAdd = 256

    # another hack to hide float (0-1) and int (0-255)
    # color table handling.
    # we assume that the column has already been created
    # of the right type appropriate for the format (maybe by calcstats)
    usage = attrTbl.GetUsageOfCol(colNum)
    if (isColorColFromUsage(usage)
            and attrTbl.GetTypeOfCol(colNum) == gdal.GFT_Real
            and colType == gdal.GFT_Integer):
        # numpy.float was removed in NumPy 1.24; float64 is what it aliased
        sequence = numpy.array(sequence, dtype=numpy.float64) / 255.0

    attrTbl.SetRowCount(rowsToAdd)
    attrTbl.WriteArray(sequence, colNum)

    if not isFileRAT:
        # assume existing bands re-written
        # Use GDAL's exceptions to trap the error message which arises when
        # writing to a format which does not support it
        usingExceptions = gdal.GetUseExceptions()
        gdal.UseExceptions()
        try:
            gdalBand.SetDefaultRAT(attrTbl)
        except Exception:
            # deliberate best-effort: unsupported formats raise here
            pass
        finally:
            # always restore the caller's exception mode
            if not usingExceptions:
                gdal.DontUseExceptions()
Esempio n. 24
0
    def run(self):
        """Run method that performs all the real work.

        Shows the plugin dialog, listing every raster layer currently in the
        QGIS layer registry as a potential input. If the dialog is accepted:

        1. the sorted unique values of band 1 of the chosen raster are read
           with GDAL;
        2. each value is decoded into a text label using the bit definitions
           in ``lookup_dict.bit_flags[band][sensor]`` (optionally dropping
           'low' confidence labels when the rmLow box is checked);
        3. the value/label pairs are written as a raster attribute table on
           band 1;
        4. the raster is added to the map as a new layer with an exact-match
           pseudocolor renderer carrying the labels and one random colour
           per value.
        """
        # show the dialog
        self.dlg.show()

        # add all raster layers in current session to UI as potential inputs
        for layer in QgsMapLayerRegistry.instance().mapLayers().values():
            if layer.type() == QgsMapLayer.RasterLayer:
                self.dlg.rasterBox.addItem(layer.name(), layer)

        # Run the dialog event loop; nothing to do unless OK was pressed
        if not self.dlg.exec_():
            return

        # TODO: add logic to auto-detect band and sensor using input_raster

        # get variable names from input
        input_raster = str(self.dlg.rasterBox.currentText())
        band = str(self.dlg.bandBox.currentText())
        sensor = str(self.dlg.sensorBox.currentText())
        rm_low = self.dlg.rmLowBox.isChecked()

        # use gdal to get the sorted unique values of band 1
        ds = gdal.Open(input_raster)
        rb = ds.GetRasterBand(1)
        values = sorted(list(np.unique(np.array(rb.ReadAsArray()))))

        # lookup table: bit_flags[band][sensor] maps a label to its bit(s)
        bit_flags = lookup_dict.bit_flags

        # convert input_sensor to sensor values used in the lookup table
        if sensor == "Landsat 4-5, 7":
            sens = "L47"
        elif sensor == "Landsat 8":
            sens = "L8"
        else:
            sys.exit("Incorrect sensor provided. Input: {0}; Potential "
                     "options: Landsat 4-5, 7; Landsat 8".format(sensor))

        flag_table = bit_flags[band][sens]

        def get_label(bits):
            """
            Generate label for value in attribute table.

            :param bits: <list> bit groups (from the lookup table) that are
                set in the value being labelled
            :return: <str> Attribute label
            """
            if len(bits) == 0:
                if band == 'radsat_qa':
                    return 'No Saturation'
                elif band == 'sr_cloud_qa' or band == 'sr_aerosol':
                    return 'None'
                elif band == 'BQA':
                    return 'Not Determined'

            # build description from all bits represented in value
            desc = ''
            for tb in bits:
                k = next(key for key, value in flag_table.items()
                         if value == tb)

                if rm_low and 'low' in k.lower():
                    # if 'low' labels are disabled, do not add them here
                    if band != 'BQA':
                        continue
                    radiometric = 'radiometric' in k.lower()
                    # if last check, and not radiometric sat, set to 'clear'
                    if tb == bits[-1] and not radiometric and not desc:
                        k = 'Clear'
                    # BQA keeps only low radiometric saturation labels
                    elif not radiometric:
                        continue

                if band == 'radsat_qa':
                    # radsat_qa is collapsed into one cleaner description
                    if not desc:
                        desc = "Band {0} Data Saturation".format(tb[0])
                    else:
                        desc = "{0},{1} Data Saturation".format(
                            desc[:desc.find('Data') - 1], tb[0])
                elif not desc:
                    desc = "{0}".format(k)
                else:
                    desc += ", {0}".format(k)

            # final check to make sure something was set
            if not desc:
                desc = 'ERROR: bit set incorrectly'

            return desc

        # get all possible bit values for sensor and band combination
        bit_values = sorted(flag_table.values())
        qa_labels = []
        for pixel_value in values:
            # a bit group counts as set only when every one of its bits is set
            bit_bool = []
            for bv in bit_values:
                if len(bv) == 0:
                    sys.exit("No valid bits found for target band.")
                bit_bool.append(all(pixel_value & 1 << b > 0 for b in bv))

            true_bits = [bv for bv, is_set in zip(bit_values, bit_bool)
                         if is_set]

            # if multi-bit groups are set, drop single bits they contain;
            # otherwise the descriptions would duplicate themselves
            dbits = [b for bv in true_bits if len(bv) > 1 for b in bv]
            if dbits:
                true_bits = [bv for bv in true_bits
                             if len(bv) > 1 or bv[0] not in dbits]

            qa_labels.append(get_label(true_bits))

        # embed the value/label pairs in band 1 as a raster attribute table
        rat = gdal.RasterAttributeTable()
        rat.CreateColumn("Value", gdalconst.GFT_Integer,
                         gdalconst.GFU_MinMax)
        rat.CreateColumn("Descr", gdalconst.GFT_String,
                         gdalconst.GFU_MinMax)
        for row, (val, lab) in enumerate(zip(values, qa_labels)):
            rat.SetValueAsInt(row, 0, int(val))
            rat.SetValueAsString(row, 1, lab)
        rb.SetDefaultRAT(rat)

        # QGIS' UI does not currently read attribute tables embedded in
        # raster datasets (feature request:
        # https://issues.qgis.org/issues/4321), so the labels are also
        # applied through the layer's colour palette with random colours.
        q_raster = QgsRasterLayer(input_raster,
                                  os.path.basename(input_raster))
        if not q_raster.isValid():
            sys.exit("Layer {0} not valid!".format(input_raster))

        c_ramp_shader = QgsColorRampShader()
        c_ramp_shader.setColorRampType(QgsColorRampShader.EXACT)

        # randint's upper bound is inclusive: 2**24 - 1 (0xffffff) is the
        # largest 6-digit hex colour; 2**24 would give an invalid '#1000000'
        c_ramp_shader.setColorRampItemList([
            QgsColorRampShader.ColorRampItem(
                float(val), QColor('#%06x' % randint(0, 2**24 - 1)), lab)
            for val, lab in zip(values, qa_labels)
        ])

        shader = QgsRasterShader()
        shader.setRasterShaderFunction(c_ramp_shader)

        # apply color ramps to raster and add it to the QGIS interface
        ps_ramp = QgsSingleBandPseudoColorRenderer(q_raster.dataProvider(),
                                                   1, shader)
        q_raster.setRenderer(ps_ramp)
        QgsMapLayerRegistry.instance().addMapLayer(q_raster)
Esempio n. 25
0
def gdal_api_proxy_sub():

    src_ds = gdal.Open('data/byte.tif')
    src_cs = src_ds.GetRasterBand(1).Checksum()
    src_gt = src_ds.GetGeoTransform()
    src_prj = src_ds.GetProjectionRef()
    src_data = src_ds.ReadRaster(0, 0, 20, 20)
    src_md = src_ds.GetMetadata()
    src_ds = None

    drv = gdal.IdentifyDriver('data/byte.tif')
    if drv.GetDescription() != 'API_PROXY':
        gdaltest.post_reason('fail')
        return 'fail'

    ds = gdal.GetDriverByName('GTiff').Create('tmp/byte.tif', 1, 1, 3)
    ds = None

    src_ds = gdal.Open('data/byte.tif')
    ds = gdal.GetDriverByName('GTiff').CreateCopy('tmp/byte.tif',
                                                  src_ds,
                                                  options=['TILED=YES'])
    got_cs = ds.GetRasterBand(1).Checksum()
    if src_cs != got_cs:
        gdaltest.post_reason('fail')
        return 'fail'
    ds = None

    ds = gdal.Open('tmp/byte.tif', gdal.GA_Update)

    ds.SetGeoTransform([1, 2, 3, 4, 5, 6])
    got_gt = ds.GetGeoTransform()
    if src_gt == got_gt:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.SetGeoTransform(src_gt)
    got_gt = ds.GetGeoTransform()
    if src_gt != got_gt:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetGCPCount() != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetGCPProjection() != '':
        print(ds.GetGCPProjection())
        gdaltest.post_reason('fail')
        return 'fail'

    if len(ds.GetGCPs()) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    gcps = [gdal.GCP(0, 1, 2, 3, 4)]
    ds.SetGCPs(gcps, "foo")

    got_gcps = ds.GetGCPs()
    if len(got_gcps) != 1:
        gdaltest.post_reason('fail')
        return 'fail'

    if got_gcps[0].GCPLine != gcps[0].GCPLine or  \
       got_gcps[0].GCPPixel != gcps[0].GCPPixel or  \
       got_gcps[0].GCPX != gcps[0].GCPX or \
       got_gcps[0].GCPY != gcps[0].GCPY:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetGCPProjection() != 'foo':
        gdaltest.post_reason('fail')
        print(ds.GetGCPProjection())
        return 'fail'

    ds.SetGCPs([], "")

    if len(ds.GetGCPs()) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.SetProjection('')
    got_prj = ds.GetProjectionRef()
    if src_prj == got_prj:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.SetProjection(src_prj)
    got_prj = ds.GetProjectionRef()
    if src_prj != got_prj:
        gdaltest.post_reason('fail')
        print(src_prj)
        print(got_prj)
        return 'fail'

    ds.GetRasterBand(1).Fill(0)
    got_cs = ds.GetRasterBand(1).Checksum()
    if 0 != got_cs:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).WriteRaster(0, 0, 20, 20, src_data)
    got_cs = ds.GetRasterBand(1).Checksum()
    if src_cs != got_cs:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).Fill(0)
    got_cs = ds.GetRasterBand(1).Checksum()
    if 0 != got_cs:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.WriteRaster(0, 0, 20, 20, src_data)
    got_cs = ds.GetRasterBand(1).Checksum()
    if src_cs != got_cs:
        gdaltest.post_reason('fail')
        return 'fail'

    # Not bound to SWIG
    # ds.AdviseRead(0,0,20,20,20,20)

    got_data = ds.ReadRaster(0, 0, 20, 20)
    if src_data != got_data:
        gdaltest.post_reason('fail')
        return 'fail'

    got_data = ds.GetRasterBand(1).ReadRaster(0, 0, 20, 20)
    if src_data != got_data:
        gdaltest.post_reason('fail')
        return 'fail'

    got_data_weird_spacing = ds.ReadRaster(0,
                                           0,
                                           20,
                                           20,
                                           buf_pixel_space=1,
                                           buf_line_space=32)
    if len(got_data_weird_spacing) != 32 * (20 - 1) + 20:
        gdaltest.post_reason('fail')
        print(len(got_data_weird_spacing))
        return 'fail'

    if got_data[20:20 + 20] != got_data_weird_spacing[32:32 + 20]:
        gdaltest.post_reason('fail')
        return 'fail'

    got_data_weird_spacing = ds.GetRasterBand(1).ReadRaster(0,
                                                            0,
                                                            20,
                                                            20,
                                                            buf_pixel_space=1,
                                                            buf_line_space=32)
    if len(got_data_weird_spacing) != 32 * (20 - 1) + 20:
        gdaltest.post_reason('fail')
        print(len(got_data_weird_spacing))
        return 'fail'

    if got_data[20:20 + 20] != got_data_weird_spacing[32:32 + 20]:
        gdaltest.post_reason('fail')
        return 'fail'

    got_block = ds.GetRasterBand(1).ReadBlock(0, 0)
    if len(got_block) != 256 * 256:
        gdaltest.post_reason('fail')
        return 'fail'

    if got_data[20:20 + 20] != got_block[256:256 + 20]:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.FlushCache()
    ds.GetRasterBand(1).FlushCache()

    got_data = ds.GetRasterBand(1).ReadRaster(0, 0, 20, 20)
    if src_data != got_data:
        gdaltest.post_reason('fail')
        return 'fail'

    if len(ds.GetFileList()) != 1:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.AddBand(gdal.GDT_Byte) == 0:
        gdaltest.post_reason('fail')
        return 'fail'

    got_md = ds.GetMetadata()
    if src_md != got_md:
        gdaltest.post_reason('fail')
        print(src_md)
        print(got_md)
        return 'fail'

    if ds.GetMetadataItem('AREA_OR_POINT') != 'Area':
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetMetadataItem('foo') is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.SetMetadataItem('foo', 'bar')
    if ds.GetMetadataItem('foo') != 'bar':
        gdaltest.post_reason('fail')
        return 'fail'

    ds.SetMetadata({'foo': 'baz'}, 'OTHER')
    if ds.GetMetadataItem('foo', 'OTHER') != 'baz':
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetMetadata({'foo': 'baw'}, 'OTHER')
    if ds.GetRasterBand(1).GetMetadataItem('foo', 'OTHER') != 'baw':
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetMetadataItem('INTERLEAVE', 'IMAGE_STRUCTURE') != 'BAND':
        gdaltest.post_reason('fail')
        return 'fail'

    if len(ds.GetRasterBand(1).GetMetadata()) != 0:
        gdaltest.post_reason('fail')
        print(ds.GetRasterBand(1).GetMetadata())
        return 'fail'

    if ds.GetRasterBand(1).GetMetadataItem('foo') is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetMetadataItem('foo', 'baz')
    if ds.GetRasterBand(1).GetMetadataItem('foo') != 'baz':
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetMetadata({'foo': 'baw'})
    if ds.GetRasterBand(1).GetMetadataItem('foo') != 'baw':
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetColorInterpretation() != gdal.GCI_GrayIndex:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetColorInterpretation(gdal.GCI_Undefined)

    ct = ds.GetRasterBand(1).GetColorTable()
    if ct is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    ct = gdal.ColorTable()
    ct.SetColorEntry(0, (1, 2, 3))
    if ds.GetRasterBand(1).SetColorTable(ct) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    ct = ds.GetRasterBand(1).GetColorTable()
    if ct is None:
        gdaltest.post_reason('fail')
        return 'fail'
    if ct.GetColorEntry(0) != (1, 2, 3, 255):
        gdaltest.post_reason('fail')
        return 'fail'

    ct = ds.GetRasterBand(1).GetColorTable()
    if ct is None:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).SetColorTable(None) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    ct = ds.GetRasterBand(1).GetColorTable()
    if ct is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    rat = ds.GetRasterBand(1).GetDefaultRAT()
    if rat is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).SetDefaultRAT(None) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    ref_rat = gdal.RasterAttributeTable()
    if ds.GetRasterBand(1).SetDefaultRAT(ref_rat) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    rat = ds.GetRasterBand(1).GetDefaultRAT()
    if rat is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).SetDefaultRAT(None) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    rat = ds.GetRasterBand(1).GetDefaultRAT()
    if rat is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetMinimum() is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    got_stats = ds.GetRasterBand(1).GetStatistics(0, 0)
    if got_stats[3] >= 0.0:
        gdaltest.post_reason('fail')
        return 'fail'

    got_stats = ds.GetRasterBand(1).GetStatistics(1, 1)
    if got_stats[0] != 74.0:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetMinimum() != 74.0:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetMaximum() != 255.0:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetStatistics(1, 2, 3, 4)
    got_stats = ds.GetRasterBand(1).GetStatistics(1, 1)
    if got_stats != [1, 2, 3, 4]:
        print(got_stats)
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).ComputeStatistics(0)
    got_stats = ds.GetRasterBand(1).GetStatistics(1, 1)
    if got_stats[0] != 74.0:
        gdaltest.post_reason('fail')
        return 'fail'

    minmax = ds.GetRasterBand(1).ComputeRasterMinMax()
    if minmax != (74.0, 255.0):
        gdaltest.post_reason('fail')
        print(minmax)
        return 'fail'

    if ds.GetRasterBand(1).GetOffset() != 0.0:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetScale() != 1.0:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetOffset(10.0)
    if ds.GetRasterBand(1).GetOffset() != 10.0:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetScale(2.0)
    if ds.GetRasterBand(1).GetScale() != 2.0:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.BuildOverviews('NEAR', [2])
    if ds.GetRasterBand(1).GetOverviewCount() != 1:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetOverview(-1) is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetOverview(0) is None:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetOverview(0) is None:
        gdaltest.post_reason('fail')
        return 'fail'

    got_hist = ds.GetRasterBand(1).GetHistogram()
    if len(got_hist) != 256:
        gdaltest.post_reason('fail')
        return 'fail'

    (minval, maxval, nitems,
     got_hist2) = ds.GetRasterBand(1).GetDefaultHistogram()
    if minval != -0.5:
        gdaltest.post_reason('fail')
        return 'fail'
    if maxval != 255.5:
        gdaltest.post_reason('fail')
        return 'fail'
    if nitems != 256:
        gdaltest.post_reason('fail')
        return 'fail'
    if got_hist != got_hist2:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetDefaultHistogram(1, 2, [3])
    (minval, maxval, nitems,
     got_hist3) = ds.GetRasterBand(1).GetDefaultHistogram()
    if minval != 1:
        gdaltest.post_reason('fail')
        return 'fail'
    if maxval != 2:
        gdaltest.post_reason('fail')
        return 'fail'
    if nitems != 1:
        gdaltest.post_reason('fail')
        return 'fail'
    if got_hist3[0] != 3:
        gdaltest.post_reason('fail')
        return 'fail'

    got_nodatavalue = ds.GetRasterBand(1).GetNoDataValue()
    if got_nodatavalue is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetNoDataValue(123)
    got_nodatavalue = ds.GetRasterBand(1).GetNoDataValue()
    if got_nodatavalue != 123:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetMaskFlags() != 8:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetMaskBand() is None:
        gdaltest.post_reason('fail')
        return 'fail'

    ret = ds.GetRasterBand(1).DeleteNoDataValue()
    if ret != 0:
        gdaltest.post_reason('fail')
        return 'fail'
    got_nodatavalue = ds.GetRasterBand(1).GetNoDataValue()
    if got_nodatavalue is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.CreateMaskBand(0)

    if ds.GetRasterBand(1).GetMaskFlags() != 2:
        print(ds.GetRasterBand(1).GetMaskFlags())
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetMaskBand() is None:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).CreateMaskBand(0)

    if ds.GetRasterBand(1).HasArbitraryOverviews() != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetUnitType('foo')
    if ds.GetRasterBand(1).GetUnitType() != 'foo':
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetCategoryNames() is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetCategoryNames(['foo'])
    if ds.GetRasterBand(1).GetCategoryNames() != ['foo']:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetDescription('bar')

    ds = None

    gdal.GetDriverByName('GTiff').Delete('tmp/byte.tif')

    return 'success'
Esempio n. 26
0
def writeColumnToBand(gdalBand,
                      colName,
                      sequence,
                      colType=None,
                      colUsage=gdal.GFU_Generic):
    """
    Write the values in ``sequence`` to the column ``colName`` of the raster
    attribute table (RAT) associated with ``gdalBand``.

    Parameters
    ----------
    gdalBand : gdal.Band
        Band whose default RAT is written. The owning GDAL dataset must have
        been created, or opened with GA_Update.
    colName : str
        Name of the column. If a column of that name already exists it is
        overwritten; otherwise a new column is created.
    sequence : sequence
        Values to write (list, tuple, numpy array etc).
    colType : int, optional
        One of gdal.GFT_Integer, gdal.GFT_Real or gdal.GFT_String. When None,
        the type is inferred with inferColumnType().
    colUsage : int, optional
        One of the gdal.GFU_* constants. Only used when the column has to be
        created. Defaults to gdal.GFU_Generic.

    Raises
    ------
    rioserrors.AttributeTableTypeError
        If the column type cannot be inferred, or colType is not a valid
        GDAL column type.
    """
    if colType is None:
        colType = inferColumnType(sequence)
    if colType is None:
        # Guard the element lookup so an empty sequence still produces the
        # documented AttributeTableTypeError instead of an IndexError.
        if len(sequence) > 0:
            msg = "Can't infer type of column for sequence of %s" % type(
                sequence[0])
        else:
            msg = "Can't infer type of column for an empty sequence"
        raise rioserrors.AttributeTableTypeError(msg)

    # check it is actually a valid type
    if colType not in (gdal.GFT_Integer, gdal.GFT_Real, gdal.GFT_String):
        msg = "coltype must be a valid gdal column type"
        raise rioserrors.AttributeTableTypeError(msg)

    # Behaviour differs depending on whether the GDAL bindings provide an
    # RFC40 RAT (detected by the presence of WriteArray) or not.
    if hasattr(gdal.RasterAttributeTable, "WriteArray"):
        # new behaviour: work on the band's own RAT where possible
        attrTbl = gdalBand.GetDefaultRAT()
        if attrTbl is None:
            # some formats eg ENVI return None
            # here so we need to be able to cope
            attrTbl = gdal.RasterAttributeTable()
            isFileRAT = False
        else:
            # If the RAT does not write changes straight to the file we
            # still have to call SetDefaultRAT at the end.
            isFileRAT = bool(attrTbl.ChangesAreWrittenToFile())
    else:
        # old behaviour: build a fresh, in-memory RAT
        attrTbl = gdal.RasterAttributeTable()
        isFileRAT = False

    # thanks to RFC40 we need to ensure colname doesn't already exist
    colNum = None
    for n in range(attrTbl.GetColumnCount()):
        if attrTbl.GetNameOfCol(n) == colName:
            colNum = n
            break
    if colNum is None:
        # preserve usage
        attrTbl.CreateColumn(colName, colType, colUsage)
        colNum = attrTbl.GetColumnCount() - 1

    rowsToAdd = len(sequence)
    # Imagine has trouble if not 256 items for byte
    if gdalBand.DataType == gdal.GDT_Byte:
        rowsToAdd = 256

    # another hack to hide float (0-1) and int (0-255)
    # color table handling.
    # we assume that the column has already been created
    # of the right type appropriate for the format (maybe by calcstats)
    # Note: this only works post RFC40 when we have an actual reference
    # to the RAT rather than a new one so we can ask GetTypeOfCol
    usage = attrTbl.GetUsageOfCol(colNum)
    if (isColorColFromUsage(usage)
            and attrTbl.GetTypeOfCol(colNum) == gdal.GFT_Real
            and colType == gdal.GFT_Integer):
        # numpy.float was removed in NumPy 1.24; float64 is the dtype the
        # old alias always resolved to.
        sequence = numpy.array(sequence, dtype=numpy.float64)
        sequence = sequence / 255.0

    if hasattr(attrTbl, "WriteArray"):
        # if GDAL > 1.10 has these functions
        # thanks to RFC40
        attrTbl.SetRowCount(rowsToAdd)
        attrTbl.WriteArray(sequence, colNum)

    elif HAVE_TURBORAT:
        # use turborat to write values to RAT if available
        if not isinstance(sequence, numpy.ndarray):
            # turborat.writeColumn needs an array
            sequence = numpy.array(sequence)

        # If the dtype of the array is some unicode type, then convert to simple string type,
        # as turborat does not cope with the unicode variant.
        if 'U' in str(sequence.dtype):
            sequence = sequence.astype(numpy.character)

        turborat.writeColumn(attrTbl, colNum, sequence, rowsToAdd)
    else:
        defaultValues = {
            gdal.GFT_Integer: 0,
            gdal.GFT_Real: 0.0,
            gdal.GFT_String: ''
        }

        # go thru and set each value into the RAT
        for rowNum in range(rowsToAdd):
            if rowNum >= len(sequence):
                # they haven't given us enough values - fill in with default
                val = defaultValues[colType]
            else:
                val = sequence[rowNum]

            if colType == gdal.GFT_Integer:
                # swig cannot convert numpy.int64 to the int required by
                # SetValueAsInt (and readColumn returns numpy.int64 for
                # integer columns), so cast explicitly. numpy.float64 ->
                # float is fine for SetValueAsDouble.
                attrTbl.SetValueAsInt(rowNum, colNum, int(val))
            elif colType == gdal.GFT_Real:
                attrTbl.SetValueAsDouble(rowNum, colNum, float(val))
            else:
                attrTbl.SetValueAsString(rowNum, colNum, val)

    if not isFileRAT:
        # assume existing bands re-written
        # Use GDAL's exceptions to trap the error message which arises when
        # writing to a format which does not support it
        usingExceptions = gdal.GetUseExceptions()
        gdal.UseExceptions()
        try:
            gdalBand.SetDefaultRAT(attrTbl)
        except Exception:
            # Deliberately best-effort: the format may not support RATs.
            pass
        finally:
            # Always restore the caller's exception mode.
            if not usingExceptions:
                gdal.DontUseExceptions()
    def save_as_xml(self, raster_source, band) -> bool:
        """Persist this RAT for a raster band through GDAL (.aux.xml sidecar).

        The raster is opened in update mode. A new gdal.RasterAttributeTable
        is filled from ``self.fields`` and ``self.data`` and attached to the
        requested band.

        :param raster_source: path of the raster data file
        :type raster_source: str
        :param band: band number
        :type band: int
        :return: True on success, False if the raster or the band could not
            be opened
        :rtype: bool
        """
        dataset = gdal.OpenEx(raster_source, gdal.OF_RASTER | gdal.OF_UPDATE)
        if not dataset:
            return False

        self.band = band
        target_band = dataset.GetRasterBand(band)
        if not target_band:
            return False

        rat = gdal.RasterAttributeTable()
        rat.SetTableType(self.thematic_type)
        for fld in self.fields.values():
            rat.CreateColumn(fld.name, fld.type, fld.usage)

        # Suffix of the RasterAttributeTable.SetValueAs* setter per field type.
        setter_suffix = {
            gdal.GFT_Integer: 'Int',
            gdal.GFT_Real: 'Double',
            gdal.GFT_String: 'String',
        }

        for col_idx, (fld_name, fld) in enumerate(self.fields.items()):
            column_values = self.data[fld_name]
            setter = getattr(rat, 'SetValueAs%s' % setter_suffix[fld.type])

            for row_idx in range(len(column_values)):
                rat_log('Writing RAT value as %s, (%s, %s) %s' %
                        (setter_suffix[fld.type], row_idx, col_idx,
                         column_values[row_idx]))
                # NOTE(review): string cells are HTML-escaped before being
                # stored; confirm GDAL's XML writer does not escape them again.
                if fld.type == gdal.GFT_String:
                    cell = html.escape(column_values[row_idx])
                else:
                    cell = column_values[row_idx]
                setter(row_idx, col_idx, cell)

        assert rat.GetColumnCount() == len(self.fields)
        assert rat.GetRowCount() == len(self.values[0])

        # Ugly hack because GDAL does not know about the newly created RAT:
        # remember this RAT and the layers reading this raster so the XML
        # can be restored when those layers are destroyed.
        matching_layers = [
            lyr for lyr in QgsProject.instance().mapLayers().values()
            if lyr.source() == raster_source
        ]
        for lyr in matching_layers:
            RAT._dirty_xml_rats["%s|%s" % (self.band, self.path)] = self
            if lyr.id() not in RAT._dirty_xml_layer_ids:
                RAT._dirty_xml_layer_ids.append(lyr.id())
                lyr.destroyed.connect(self._restore_xml_rats)

        # Set + flush twice: empirically a single pass does not reliably get
        # the RAT written into the XML.
        for _ in range(2):
            target_band.SetDefaultRAT(rat)
            dataset.FlushCache()
        rat_log('RAT saved as XML for layer %s' % raster_source)

        return True