Ejemplo n.º 1
0
def create_colortable(ct_file):
    """Build a GDAL color table from a color table file.

    :param ct_file: handed unchanged to ``parse_color_table_file``; every
        entry it yields is read as ``(index, component, component, ...)``.
    :return: the populated ``gdal.ColorTable``.
    """
    rows = parse_color_table_file(ct_file)
    color_table = gdal.ColorTable()
    for row in rows:
        index = row[0]
        components = tuple(row[1:])
        color_table.SetColorEntry(index, components)

    return color_table
Ejemplo n.º 2
0
def load_color_table(nc_vname):
    """Load a color table for use with a palettized GeoTIFF.

    The table is read from ``color_tables/<name>.ct`` (relative to the
    current working directory).  The first line of that file is the number
    of entries; each following line holds comma-separated integers, of which
    the first four are used as one color entry.

    :param string nc_vname: name of the variable which maps to a color
        table.  Several variables share one table (e.g. "tmx" and "tmn"
        both use the "tmp" table).
    :return: the populated ``gdal.ColorTable``.
    :raises ValueError: if no color table is defined for ``nc_vname``.
    """
    # Variable name -> base name of the .ct file it uses.
    table_for_variable = {
        "tmp": "tmp",
        "tmx": "tmp",
        "tmn": "tmp",
        "cld": "cld",
        "wet": "pre",
        "vap": "pre",
        "pre": "pre",
        "pet": "pre",
        "frs": "frs",
        "dtr": "dtr",
    }
    ct_name = table_for_variable.get(nc_vname)
    if ct_name is None:
        raise ValueError("No color table defined for variable: " + nc_vname)

    fname = "color_tables/" + ct_name + ".ct"
    # The file is plain text; the context manager guarantees the handle is
    # closed even if reading fails.
    with open(fname, 'r') as fh:
        lines = fh.readlines()

    # create GDAL color table
    ct = gdal.ColorTable()

    # first line is length of color table
    ct_length = int(lines[0])
    for p in range(ct_length):
        c = lines[p + 1].split(",")
        ct.SetColorEntry(p, (int(c[0]), int(c[1]), int(c[2]), int(c[3])))
    return ct
Ejemplo n.º 3
0
def CreateColorTable(fileLUT, logger_=logger):
    """
    Build a GDAL color table from a whitespace-separated color file.

    IN :
        fileLUT [string] : path to the color file table
            ex : for a table containing 3 classes ("8","90","21"), "8" must be represent in red, "90" in green, "21" in blue
                cat /path/to/myColorTable.csv
                8 255 0 0
                90 0 255 0
                21 0 0 255
        logger_ : object whose ``warning`` method is called when an entry
            cannot be set
    OUT :
        ct [gdalColorTable]

    When setting an entry fails, a warning is logged and entry 0 is set to
    white (255, 255, 255) instead.
    """
    ct = gdal.ColorTable()
    # The context manager closes the file even if a line fails to parse.
    with open(fileLUT) as filein:
        for line in filein:
            # split() tolerates repeated spaces, tabs and the trailing newline.
            classID = line.split()
            codeColor = [int(i) for i in classID[1:4]]
            try:
                ct.SetColorEntry(int(classID[0]), tuple(codeColor))
            except Exception:
                # Not a bare except: KeyboardInterrupt/SystemExit must propagate.
                logger_.warning(
                    "a color entry was not recognize, default value set. Class label 0, RGB code : 255, 255, 255"
                )
                ct.SetColorEntry(0, (255, 255, 255))
    return ct
Ejemplo n.º 4
0
def pcidsk_7():
    """Check color table handling on band 1 of the shared PCIDSK dataset.

    Steps: confirm there is no table at first, set one and read it back,
    close and reopen 'tmp/pcidsk_5.pix' and read it again, check the palette
    color interpretation, then remove the table.  Returns 'skip', 'fail' or
    'success'.
    """
    if gdaltest.pcidsk_new == 0:
        return 'skip'

    def _fail(reason):
        gdaltest.post_reason(reason)
        return 'fail'

    probe_entry = (255, 0, 255, 255)

    band = gdaltest.pcidsk_ds.GetRasterBand(1)
    ct = band.GetColorTable()
    if ct is not None:
        return _fail('Got color table unexpectedly.')

    # Install a three-entry table; entry 1 is the one probed below.
    ct = gdal.ColorTable()
    for index, rgba in enumerate(((0, 255, 0, 255), probe_entry,
                                  (0, 0, 255, 255))):
        ct.SetColorEntry(index, rgba)
    band.SetColorTable(ct)

    ct = band.GetColorTable()
    if ct.GetColorEntry(1) != probe_entry:
        return _fail('Got wrong color table entry immediately.')

    # Drop the local references before closing the dataset.
    ct = None
    band = None

    # Close and reopen.
    gdaltest.pcidsk_ds = None
    gdaltest.pcidsk_ds = gdal.Open('tmp/pcidsk_5.pix', gdal.GA_Update)

    band = gdaltest.pcidsk_ds.GetRasterBand(1)
    ct = band.GetColorTable()
    if ct.GetColorEntry(1) != probe_entry:
        return _fail('Got wrong color table entry after reopen.')

    if band.GetColorInterpretation() != gdal.GCI_PaletteIndex:
        return _fail('Not a palette?')

    # Removing the table should also reset the color interpretation.
    if band.SetColorTable(None) != 0:
        return _fail('SetColorTable failed.')

    if band.GetColorTable() is not None:
        return _fail('color table still exists!')

    if band.GetColorInterpretation() != gdal.GCI_Undefined:
        return _fail('Paletted?')

    return 'success'
Ejemplo n.º 5
0
def add_colortable(n_out, source):
    """Attach the color table registered for ``source`` to band 1 of ``n_out``.

    Every color in ``colorDict[source]`` is extended with an alpha of 255
    before being stored.  Returns ``n_out``.
    """
    table = gdal.ColorTable()
    for index in colorDict[source]:
        opaque = colorDict[source][index] + (255, )
        table.SetColorEntry(index, opaque)

    target_band = n_out.vrt.dataset.GetRasterBand(1)
    target_band.SetColorTable(table)
    return n_out
Ejemplo n.º 6
0
def palette2colortable(pal):
    """Convert a flat palette sequence into a GDAL color table.

    ``pal`` is read in consecutive groups of three values; group ``i``
    becomes the color entry at index ``i``.  Trailing values that do not
    fill a complete group are ignored.

    :param pal: sliceable sequence of values convertible with ``int``.
    :return: the populated ``gdal.ColorTable``.
    """
    # Prefer the packaged module; fall back to the legacy top-level import
    # for older installations.
    try:
        from osgeo import gdal
    except ImportError:
        import gdal
    colortable = gdal.ColorTable()
    # Floor division keeps the count an int on Python 3 (true division
    # would hand range() a float and raise TypeError).
    for index in range(len(pal) // 3):
        start = index * 3
        entry = tuple(int(col) for col in pal[start:start + 3])
        colortable.SetColorEntry(index, entry)
    return colortable
Ejemplo n.º 7
0
def colortable_1():
    """Create the shared test color table and its source data on ``gdaltest``.

    Stores the list of colors in ``gdaltest.test_ct_data`` and a
    ``gdal.ColorTable`` filled from it in ``gdaltest.test_ct``.
    """
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 255, 0)]
    gdaltest.test_ct_data = colors

    table = gdal.ColorTable()
    gdaltest.test_ct = table
    for index, color in enumerate(colors):
        table.SetColorEntry(index, color)

    return 'success'
Ejemplo n.º 8
0
def create_color_table(R, G, B):
    """Build a fully opaque color table from parallel channel sequences.

       :param array(byte) R: red values, length n
       :param array(byte) G: green values, length n
       :param array(byte) B: blue values, length n
       :return: ``gdal.ColorTable`` whose entry ``i`` is
           ``(R[i], G[i], B[i], 255)``
    """
    # All three channels must describe the same number of entries.
    assert len(R) == len(G) == len(B)

    color_table = gdal.ColorTable()
    for index, (red, green, blue) in enumerate(zip(R, G, B)):
        color_table.SetColorEntry(index, (red, green, blue, 255))
    return color_table
def RasterPalette(paletteT):
    """Build a GDAL color table of ramps between consecutive palette stops.

    ``FixGDALPalette(paletteT)`` returns three values; only the first is
    used here.  Each of its items is read as ``(index, color)`` and a color
    ramp is created between every pair of neighbouring items.
    """
    stops, _unused_at, _unused_max_at = FixGDALPalette(paletteT)
    table = gdal.ColorTable()
    # A discrete (non-ramped) variant would instead call
    # table.SetColorEntry(stop[0], stop[1]) for every stop.
    for upper in range(1, len(stops)):
        ramp_start = stops[upper - 1]
        ramp_end = stops[upper]
        table.CreateColorRamp(ramp_start[0], ramp_start[1],
                              ramp_end[0], ramp_end[1])
    return table
    '''
Ejemplo n.º 10
0
def create_colortable(ct_file_or_entries):
    """Create a GDAL ColorTable from a file name, a Colormap or raw entries.

    * ``str``: parsed with ``parse_color_table_file``.
    * ``Colormap``: each color in ``.colors`` is scaled by 255, truncated
      to int and numbered from 0.
    * anything else: iterated directly as ``(index, component, ...)`` rows.
    """

    def _scaled_rows(colors):
        # Lazily yield (index, c0, c1, ...) with components scaled to ints.
        for position, color in enumerate(colors):
            yield (position, ) + tuple(int(c * 255.) for c in color)

    source = ct_file_or_entries
    if isinstance(source, str):
        rows = parse_color_table_file(source)
    elif isinstance(source, Colormap):
        rows = _scaled_rows(source.colors)
    else:
        rows = source

    table = gdal.ColorTable()
    for row in rows:
        table.SetColorEntry(row[0], tuple(row[1:]))

    return table
Ejemplo n.º 11
0
def gen_GTiff(pgm_lines, gtiff_fname, proj4string):
    """Write a palettized single-band GeoTIFF from PGM text lines.

    The header (first seven lines) supplies gain, offset, width, height,
    pixel size and the lower-left lon/lat.  Data values are scaled with
    gain/offset, classified with ``Natural_Breaks(..., 8, 0)`` and the
    resulting ``.yb`` array is written as the band data
    (presumably class indices 0..7 -- confirm against Natural_Breaks).

    :param pgm_lines: lines of the PGM file, header included
    :param gtiff_fname: path of the GeoTIFF to create
    :param proj4string: PROJ.4 definition used both to project the
        lower-left corner and as the dataset projection
    """
    pgm_header_info = parse_header_lines(pgm_lines[0:7])
    pgm_data_array = gen_array(pgm_lines)
    pgm_shape = pgm_data_array.shape
    val_array = pgm_data_array * float(pgm_header_info['gain']) + float(
        pgm_header_info['offset'])
    color_idx_array = Natural_Breaks(val_array.ravel(), 8, 0).yb
    color_idx_array.shape = pgm_shape

    # Select driver and create the output file
    driver = gdal.GetDriverByName('GTiff')
    dataset = driver.Create(gtiff_fname, pgm_header_info['width'],
                            pgm_header_info['height'], 1, gdal.GDT_Byte)

    # Project the lower-left corner, then derive the upper-left origin that
    # the geotransform expects.
    epsg = Proj(proj4string)
    logging.debug('Lower left coordinate (long, lat) : (%f, %f)',
                  pgm_header_info['lon0'], pgm_header_info['lat0'])
    x_bl, y_bl = epsg(pgm_header_info['lon0'], pgm_header_info['lat0'])
    logging.debug('Lower left coordinates (x, y): (%f, %f)', x_bl, y_bl)
    x0 = x_bl
    y0 = y_bl + pgm_header_info['pixsize'] * pgm_header_info['height']

    dataset.SetGeoTransform([
        x0, pgm_header_info['pixsize'], 0, y0, 0,
        -1 * pgm_header_info['pixsize']
    ])
    srs = osr.SpatialReference()
    srs.ImportFromProj4(proj4string)
    dataset.SetProjection(srs.ExportToWkt())

    band = dataset.GetRasterBand(1)
    band.WriteArray(color_idx_array)
    band.SetColorInterpretation(gdal.GCI_PaletteIndex)

    # Entry 0 is white; entries 1..7 come from the 7-class YlOrRd palette.
    # Color tuples are passed straight to SetColorEntry: the previous code
    # overwrote the SWIG proxy's internal 'this' slot on throwaway
    # gdal.ColorEntry objects just to carry the tuple.
    color_table = gdal.ColorTable(gdal.GPI_RGB)
    color_table.SetColorEntry(0, (255, 255, 255, 255))
    for i in range(7):
        color_table.SetColorEntry(i + 1, YlOrRd[7][i])
    band.SetColorTable(color_table)

    # Release the band before the dataset so the file is flushed and closed.
    band = None
    dataset = None
Ejemplo n.º 12
0
def colortable_3():
    """Check that CreateColorRamp sets both end points of a 0..255 ramp.

    Returns 'skip' when the bindings do not expose
    ``ColorTable.CreateColorRamp``, otherwise 'fail' or 'success'.
    """
    ct = gdal.ColorTable()
    try:
        ct.CreateColorRamp
    except AttributeError:
        # Only a missing method means "not supported"; anything else should
        # surface instead of being reported as a skip.
        return 'skip'

    ct.CreateColorRamp(0, (255, 0, 0), 255, (0, 0, 255))

    if ct.GetColorEntry(0) != (255, 0, 0, 255):
        return 'fail'

    if ct.GetColorEntry(255) != (0, 0, 255, 255):
        return 'fail'

    return 'success'
Ejemplo n.º 13
0
def write_geotiff(fname,
                  data,
                  geo_transform,
                  projection,
                  classes,
                  COLORS,
                  data_type=gdal.GDT_Byte):
    """
    Write ``data`` to a single-band GeoTIFF with a color table and TIFF tags.
    :param fname: Path of the GeoTIFF file to create
    :param data: 2-D array; its shape is read as (rows, cols) and it is written to band 1
    :param geo_transform: Geotransform coefficients applied to the new dataset
                          (e.g. the value returned by gdal.Dataset.GetGeoTransform)
    :param projection: Projection definition string (e.g. as returned by gdal.Dataset.GetProjectionRef)
    :param classes: Sized collection; color entries 0..len(classes) are written
    :param COLORS: Indexable by pixel value; each item is a string whose characters
                   1-2, 3-4 and 5-6 are the hex red, green and blue components
    :param data_type: GDAL data type of the band
    """
    gtiff_driver = gdal.GetDriverByName('GTiff')
    n_rows, n_cols = data.shape
    out_ds = gtiff_driver.Create(fname, n_cols, n_rows, 1, data_type)
    out_ds.SetGeoTransform(geo_transform)
    out_ds.SetProjection(projection)
    out_band = out_ds.GetRasterBand(1)
    out_band.WriteArray(data)

    palette = gdal.ColorTable()
    for value in range(len(classes) + 1):
        hex_code = COLORS[value]
        # Skip the leading character, then read three two-digit hex pairs.
        red, green, blue = (int(hex_code[start:start + 2], 16)
                            for start in (1, 3, 5))
        palette.SetColorEntry(value, (red, green, blue, 255))
    out_band.SetColorTable(palette)

    tiff_tags = {}
    tiff_tags['TIFFTAG_COPYRIGHT'] = 'CC BY 4.0, AND BEN HICKSON'
    tiff_tags['TIFFTAG_DOCUMENTNAME'] = 'Land Cover Classification'
    tiff_tags['TIFFTAG_IMAGEDESCRIPTION'] = (
        'Random Forests Supervised classification.')
    tiff_tags['TIFFTAG_MAXSAMPLEVALUE'] = str(len(classes))
    tiff_tags['TIFFTAG_MINSAMPLEVALUE'] = '0'
    tiff_tags['TIFFTAG_SOFTWARE'] = 'Python, GDAL, scikit-learn'
    out_ds.SetMetadata(tiff_tags)

    out_ds = None  # Close the file
Ejemplo n.º 14
0
def colortable(ctype):
    """
    Return a gdal color table for one of a few built-in presets.

    Additional presets can be added to the table below. See
    https://gdal.org/doxygen/structGDALColorEntry.html and
    https://gis.stackexchange.com/questions/158195/python-gdal-create-geotiff-from-array-with-colormapping
    for guidance.

    Parameters
    ----------
    ctype : str
        Name of the preset: one of
        {'binary', 'skel', 'mask', 'tile', or 'GSW'}. Any other value
        yields a color table with no entries set.

    Returns
    -------
    color_table : gdal.ColorTable()
        Color table that can be supplied to gdal when creating a raster.

    """
    # Each color is (R, G, B, alpha); the last value is transparency.
    clear = (0, 0, 0, 0)
    presets = (
        ('binary', (clear, (255, 255, 255, 100))),
        ('skel', (clear, (255, 0, 255, 100))),
        ('mask', (clear, (0, 128, 0, 100))),
        ('tile', (clear, (0, 0, 255, 100))),
        ('GSW', (clear, clear, (176, 224, 230, 100))),
    )

    color_table = gdal.ColorTable()
    for preset_name, preset_colors in presets:
        if ctype == preset_name:
            for index, rgba in enumerate(preset_colors):
                color_table.SetColorEntry(index, rgba)
            break

    return color_table
def reclassify():
    """Interactively create 'dem_class4.tif' with a user-entered color table.

    Prompts for a working directory (which becomes the current directory)
    and a source raster, creates a one-band Byte raster of the same size
    with the source's driver, then asks for each class how many color
    components to read and reads them one by one.
    """
    os.chdir(str(input('enter diectory path')))
    source = gdal.Open(str(input("file name")))
    driver_name = source.GetDriver().ShortName
    out_driver = gdal.GetDriverByName(str(driver_name))
    out_ds = out_driver.Create('dem_class4.tif', source.RasterXSize,
                               source.RasterYSize, 1, gdal.GDT_Byte)
    out_band = out_ds.GetRasterBand(1)

    palette = gdal.ColorTable()
    class_count = int(input("enter number of class for classification"))
    for class_index in range(class_count):
        component_count = int(input("enter number of color to be entered"))
        components = []
        for _ in range(component_count):
            components.append(int(input("enter color number")))
        palette.SetColorEntry(class_index, tuple(components))

    out_band.SetRasterColorTable(palette)
    out_band.SetRasterColorInterpretation(gdal.GCI_PaletteIndex)
    # Release the band before the dataset.
    del out_band, out_ds
Ejemplo n.º 16
0
def test_gdal_translate_19():
    """Run ``gdal_translate -expand rgba`` on a paletted two-band source.

    The 1x1 source has color entry 127 = (1, 2, 3, 255) on band 1 (filled
    with 127) and band 2 filled with 250.  The four output bands are then
    checked against the expected checksums.
    """
    if test_cli_utilities.get_gdal_translate_path() is None:
        return 'skip'

    src_ds = gdal.GetDriverByName('GTiff').Create(
        'tmp/test_gdal_translate_19_src.tif', 1, 1, 2)
    palette = gdal.ColorTable()
    palette.SetColorEntry(127, (1, 2, 3, 255))
    src_ds.GetRasterBand(1).SetRasterColorTable(palette)
    src_ds.GetRasterBand(1).Fill(127)
    src_ds.GetRasterBand(2).Fill(250)
    src_ds = None

    gdaltest.runexternal(
        test_cli_utilities.get_gdal_translate_path() +
        ' -expand rgba tmp/test_gdal_translate_19_src.tif tmp/test_gdal_translate_19_dst.tif'
    )

    ds = gdal.Open('tmp/test_gdal_translate_19_dst.tif')
    if ds is None:
        return 'fail'

    # (band number, expected checksum, failure reason)
    expectations = (
        (1, 1, 'Bad checksum for band 1'),
        (2, 2, 'Bad checksum for band 2'),
        (3, 3, 'Bad checksum for band 3'),
        (4, 250 % 7, 'Bad checksum for band 4'),
    )
    for band_number, expected, reason in expectations:
        if ds.GetRasterBand(band_number).Checksum() != expected:
            gdaltest.post_reason(reason)
            return 'fail'

    ds = None

    return 'success'
Ejemplo n.º 17
0
def ehdr_4():
    """Create 'tmp/test_4.bil' (EHdr) with a color table and nodata value.

    Writes 100 rows of the values 0..199 (supplied as an Int16 buffer into
    a Byte band), attaches a four-entry color table to band 1 and sets the
    nodata value to 17.  Always returns 'success'.
    """
    drv = gdal.GetDriverByName('EHdr')
    ds = drv.Create('tmp/test_4.bil', 200, 100, 1, gdal.GDT_Byte)

    samples = array.array('h', list(range(200)))
    # array.tostring() was removed in Python 3.9; tobytes() is the
    # replacement.  Fall back to tostring() only where tobytes() is missing.
    if hasattr(samples, 'tobytes'):
        raw_data = samples.tobytes()
    else:
        raw_data = samples.tostring()

    for line in range(100):
        ds.WriteRaster(0, line, 200, 1, raw_data, buf_type=gdal.GDT_Int16)

    ct = gdal.ColorTable()
    ct.SetColorEntry(0, (255, 255, 255, 255))
    ct.SetColorEntry(1, (255, 255, 0, 255))
    ct.SetColorEntry(2, (255, 0, 255, 255))
    ct.SetColorEntry(3, (0, 255, 255, 255))

    ds.GetRasterBand(1).SetRasterColorTable(ct)

    ds.GetRasterBand(1).SetNoDataValue(17)

    # Dropping the last reference closes the dataset.
    ds = None

    return 'success'
Ejemplo n.º 18
0
def CreateColorTable(fileLUT):
    """
    Build a GDAL color table from a whitespace-separated color file.

    IN :
        fileLUT [string] : path to the color file table
            ex : for a table containing 3 classes ("8","90","21"), "8" must be represent in red, "90" in green, "21" in blue
                cat /path/to/myColorTable.csv
                8 255 0 0
                90 0 255 0
                21 0 0 255
    OUT :
        ct [gdalColorTable]

    Lines with fewer than four fields (e.g. blank lines) are skipped.
    """
    ct = gdal.ColorTable()
    # The context manager closes the file even if a line fails to parse.
    with open(fileLUT) as filein:
        for line in filein:
            # split() tolerates repeated spaces, tabs and the trailing newline.
            classID = line.split()
            if len(classID) < 4:
                continue
            codeColor = [int(i) for i in classID[1:4]]
            ct.SetColorEntry(int(classID[0]), tuple(codeColor))
    return ct
Ejemplo n.º 19
0
def format_colortable(name,
                      vmin=0.,
                      vmax=1.,
                      vmin_pal=0.,
                      vmax_pal=1.,
                      index_min=0,
                      index_max=254,
                      index_nodata=255):
    """Build a GDAL color table by sampling a named colormap.

    ``load_colormap(name)`` supplies a callable colormap and a nodata
    color.  ``[vmin, vmax]`` is normalised against ``[vmin_pal, vmax_pal]``
    and the colormap is sampled at ``index_max - index_min + 1`` evenly
    spaced positions in that normalised range; each sample is multiplied by
    255, rounded, and its first three components are stored at indices
    ``index_min`` .. ``index_max``.

    :param name: colormap name passed to ``load_colormap``
    :param vmin: lower bound of the data range
    :param vmax: upper bound of the data range
    :param vmin_pal: data value mapped to the start of the colormap
    :param vmax_pal: data value mapped to the end of the colormap
    :param index_min: first color table index to fill
    :param index_max: last color table index to fill (inclusive)
    :param index_nodata: index receiving the nodata color, or ``None`` to
        leave it unset
    :return: the populated ``gdal.ColorTable``
    """
    colormap, nodata_color = load_colormap(name)
    pal_span = float(vmax_pal - vmin_pal)
    norm_min = (vmin - vmin_pal) / pal_span
    norm_max = (vmax - vmin_pal) / pal_span

    first_index = int(index_min)
    last_index = int(index_max)
    ncols = last_index - first_index + 1
    colors = colormap(np.linspace(norm_min, norm_max, num=ncols))
    colors = np.round(colors * 255)

    colortable = gdal.ColorTable()
    for offset, color in enumerate(colors):
        rgb = (int(color[0]), int(color[1]), int(color[2]))
        colortable.SetColorEntry(first_index + offset, rgb)
    # Identity test: None is the only value meaning "no nodata entry".
    if index_nodata is not None:
        colortable.SetColorEntry(index_nodata, nodata_color)
    return colortable
Ejemplo n.º 20
0
def colortable(ctype):
    """Return a gdal color table for a named preset.

    Recognised names are 'binary', 'skel', 'mask', 'tile' and 'JRCmo'; any
    other value yields a table with no entries set.  Each color is
    (R, G, B, alpha).  See http://www.gdal.org/structGDALColorEntry.html
    and https://gis.stackexchange.com/questions/158195/python-gdal-create-geotiff-from-array-with-colormapping
    """
    color_table = gdal.ColorTable()

    def _fill(*colors):
        # Store the given colors at consecutive indices starting from 0.
        for index, rgba in enumerate(colors):
            color_table.SetColorEntry(index, rgba)

    clear = (0, 0, 0, 0)
    if ctype == 'binary':
        _fill(clear, (255, 255, 255, 100))
    elif ctype == 'skel':
        _fill(clear, (255, 0, 255, 100))
    elif ctype == 'mask':
        _fill(clear, (0, 128, 0, 100))
    elif ctype == 'tile':
        _fill(clear, (0, 0, 255, 100))
    elif ctype == 'JRCmo':
        _fill(clear, clear, (176, 224, 230, 100))

    return color_table
Ejemplo n.º 21
0
def combine_class(band):
    """Set palette interpretation and category names on a classified band.

    Reads the band's default raster attribute table, collects per-row
    values from columns 2-6 and assigns the column-3 strings as the band's
    category names.  A new RGB color table is created but never attached to
    the band.  Returns None.

    NOTE(review): this block uses Python 2 print statements and appears to
    be work in progress (several unused locals, large commented-out
    sections, and a suspicious SetColorEntry call flagged below).
    """

    # NOTE(review): 'classes' is never used; the same call is repeated for
    # 'cat_names' further down.
    classes = band.GetCategoryNames()
    color_int = band.GetRasterColorInterpretation()

    attr_table = band.GetDefaultRAT()
    # NOTE(review): 'cols' is only referenced from commented-out code.
    cols = attr_table.GetColumnCount()
    rows = attr_table.GetRowCount()
    cat_names = band.GetCategoryNames()
    #print cat_names
    color_interpretation = band.GetRasterColorInterpretation()
    print 'band color interpretation is %s' % color_interpretation
    # GCI_PaletteIndex is referenced unqualified -- presumably star-imported
    # from gdalconst elsewhere in the file; confirm.
    band.SetRasterColorInterpretation(GCI_PaletteIndex)
    # NOTE(review): 'color_int' was read before the Set call above, so this
    # prints the earlier value rather than the newly set one.
    print 'band color interpretation is set to %s' % color_int
    """The output class image from ArcGIS has no color table"""
    #color_table = band.GetColorTable()
    #print color_table
    #print color_table.GetColorEntryCount()

    # initialize new color table for image
    color_table = gdal.ColorTable(GPI_RGB)
    print 'color table palette interpretation is %s' % color_table.GetPaletteInterpretation(
    )
    #print 'color table entry count is %s' % color_table.GetCount()

    # NOTE(review): the table itself is passed where an entry index is
    # expected and a bare int where a color tuple is expected -- this call
    # looks wrong as written; confirm the intended index and color.
    color_entry = color_table.SetColorEntry(color_table, 255)
    #print color_entry(255,255,255,0)

    # prepare information from the raster attribute table for setting up color table
    class_value = []
    class_name = []
    rgb = []  # to be used for setting up class colors

    # get column headings
    #for col in range(cols):
    #print attr_table.GetNameOfCol(col)

    #print "\n"

    # Assumes columns 2..6 hold class value, class name, red, green, blue
    # (inferred from the variable names) -- TODO confirm against the RAT.
    for row in range(rows):
        class_value.append(attr_table.GetValueAsString(row, 2))
        class_name.append(attr_table.GetValueAsString(row, 3))
        rgb.append((attr_table.GetValueAsString(row, 4),
                    attr_table.GetValueAsString(row, 5),
                    attr_table.GetValueAsString(row, 6)))

    #print class_name

    band.SetRasterCategoryNames(class_name)

    #for i in range(len(rgb)):
    # set color palette
    #print rgb[i][0]
    #red = color_table.GetColorEntry(int(rgb[i][0]))
    #print red
    #green = color_table.GetPaletteInterpretation(rgb[i][1])
    #blue = color_table.GetPaletteInterpretation(rgb[i][2])

    #color_table.SetColorEntry(i, red, green, blue)

    #print class_value[i], class_name[i], rgb[i]

    #band.FlushCache()

    return
Ejemplo n.º 22
0
def gdal_api_proxy_sub():
    """Exercise a broad set of Dataset and Band calls through the API_PROXY driver.

    Reads reference values from 'data/byte.tif', copies it to
    'tmp/byte.tif', then checks each setter/getter pair in turn:
    geotransform, GCPs, projection, raster I/O, metadata, color table,
    raster attribute table, statistics, offset/scale, overviews,
    histograms, nodata, mask bands, unit type and category names.
    Returns 'fail' at the first mismatch (after posting the reason) and
    'success' once the temporary file has been deleted.
    """

    # Capture reference values from the source file.
    src_ds = gdal.Open('data/byte.tif')
    src_cs = src_ds.GetRasterBand(1).Checksum()
    src_gt = src_ds.GetGeoTransform()
    src_prj = src_ds.GetProjectionRef()
    src_data = src_ds.ReadRaster(0, 0, 20, 20)
    src_md = src_ds.GetMetadata()
    src_ds = None

    # The file must be identified as being handled by the API_PROXY driver.
    drv = gdal.IdentifyDriver('data/byte.tif')
    if drv.GetDescription() != 'API_PROXY':
        gdaltest.post_reason('fail')
        return 'fail'

    # Plain Create() followed by an immediate close.
    ds = gdal.GetDriverByName('GTiff').Create('tmp/byte.tif', 1, 1, 3)
    ds = None

    # CreateCopy() must preserve the band checksum.
    src_ds = gdal.Open('data/byte.tif')
    ds = gdal.GetDriverByName('GTiff').CreateCopy('tmp/byte.tif',
                                                  src_ds,
                                                  options=['TILED=YES'])
    got_cs = ds.GetRasterBand(1).Checksum()
    if src_cs != got_cs:
        gdaltest.post_reason('fail')
        return 'fail'
    ds = None

    # All remaining checks run against the copy opened in update mode.
    ds = gdal.Open('tmp/byte.tif', gdal.GA_Update)

    # Geotransform: set a different one, then restore the original.
    ds.SetGeoTransform([1, 2, 3, 4, 5, 6])
    got_gt = ds.GetGeoTransform()
    if src_gt == got_gt:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.SetGeoTransform(src_gt)
    got_gt = ds.GetGeoTransform()
    if src_gt != got_gt:
        gdaltest.post_reason('fail')
        return 'fail'

    # GCPs: none initially, then set one, then clear.
    if ds.GetGCPCount() != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetGCPProjection() != '':
        print(ds.GetGCPProjection())
        gdaltest.post_reason('fail')
        return 'fail'

    if len(ds.GetGCPs()) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    gcps = [gdal.GCP(0, 1, 2, 3, 4)]
    ds.SetGCPs(gcps, "foo")

    got_gcps = ds.GetGCPs()
    if len(got_gcps) != 1:
        gdaltest.post_reason('fail')
        return 'fail'

    if got_gcps[0].GCPLine != gcps[0].GCPLine or  \
       got_gcps[0].GCPPixel != gcps[0].GCPPixel or  \
       got_gcps[0].GCPX != gcps[0].GCPX or \
       got_gcps[0].GCPY != gcps[0].GCPY:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetGCPProjection() != 'foo':
        gdaltest.post_reason('fail')
        print(ds.GetGCPProjection())
        return 'fail'

    ds.SetGCPs([], "")

    if len(ds.GetGCPs()) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    # Projection: clear it, then restore the original.
    ds.SetProjection('')
    got_prj = ds.GetProjectionRef()
    if src_prj == got_prj:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.SetProjection(src_prj)
    got_prj = ds.GetProjectionRef()
    if src_prj != got_prj:
        gdaltest.post_reason('fail')
        print(src_prj)
        print(got_prj)
        return 'fail'

    # Raster writes: zero the band, then restore via the band-level and
    # dataset-level WriteRaster, checking the checksum after each step.
    ds.GetRasterBand(1).Fill(0)
    got_cs = ds.GetRasterBand(1).Checksum()
    if 0 != got_cs:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).WriteRaster(0, 0, 20, 20, src_data)
    got_cs = ds.GetRasterBand(1).Checksum()
    if src_cs != got_cs:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).Fill(0)
    got_cs = ds.GetRasterBand(1).Checksum()
    if 0 != got_cs:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.WriteRaster(0, 0, 20, 20, src_data)
    got_cs = ds.GetRasterBand(1).Checksum()
    if src_cs != got_cs:
        gdaltest.post_reason('fail')
        return 'fail'

    # Not bound to SWIG
    # ds.AdviseRead(0,0,20,20,20,20)

    # Raster reads: dataset-level and band-level must return the source data.
    got_data = ds.ReadRaster(0, 0, 20, 20)
    if src_data != got_data:
        gdaltest.post_reason('fail')
        return 'fail'

    got_data = ds.GetRasterBand(1).ReadRaster(0, 0, 20, 20)
    if src_data != got_data:
        gdaltest.post_reason('fail')
        return 'fail'

    # Reads with a 32-byte line stride: the buffer is 19 full strides plus
    # one final 20-byte row, and row 1 must match the packed read.
    got_data_weird_spacing = ds.ReadRaster(0,
                                           0,
                                           20,
                                           20,
                                           buf_pixel_space=1,
                                           buf_line_space=32)
    if len(got_data_weird_spacing) != 32 * (20 - 1) + 20:
        gdaltest.post_reason('fail')
        print(len(got_data_weird_spacing))
        return 'fail'

    if got_data[20:20 + 20] != got_data_weird_spacing[32:32 + 20]:
        gdaltest.post_reason('fail')
        return 'fail'

    got_data_weird_spacing = ds.GetRasterBand(1).ReadRaster(0,
                                                            0,
                                                            20,
                                                            20,
                                                            buf_pixel_space=1,
                                                            buf_line_space=32)
    if len(got_data_weird_spacing) != 32 * (20 - 1) + 20:
        gdaltest.post_reason('fail')
        print(len(got_data_weird_spacing))
        return 'fail'

    if got_data[20:20 + 20] != got_data_weird_spacing[32:32 + 20]:
        gdaltest.post_reason('fail')
        return 'fail'

    # Block read: expects a 256x256 block whose second row starts with the
    # second row of the image data.
    got_block = ds.GetRasterBand(1).ReadBlock(0, 0)
    if len(got_block) != 256 * 256:
        gdaltest.post_reason('fail')
        return 'fail'

    if got_data[20:20 + 20] != got_block[256:256 + 20]:
        gdaltest.post_reason('fail')
        return 'fail'

    # Data must survive a cache flush at both dataset and band level.
    ds.FlushCache()
    ds.GetRasterBand(1).FlushCache()

    got_data = ds.GetRasterBand(1).ReadRaster(0, 0, 20, 20)
    if src_data != got_data:
        gdaltest.post_reason('fail')
        return 'fail'

    if len(ds.GetFileList()) != 1:
        gdaltest.post_reason('fail')
        return 'fail'

    # AddBand() is expected to report a non-zero (error) result here.
    if ds.AddBand(gdal.GDT_Byte) == 0:
        gdaltest.post_reason('fail')
        return 'fail'

    # Dataset metadata: default domain, single items and a named domain.
    got_md = ds.GetMetadata()
    if src_md != got_md:
        gdaltest.post_reason('fail')
        print(src_md)
        print(got_md)
        return 'fail'

    if ds.GetMetadataItem('AREA_OR_POINT') != 'Area':
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetMetadataItem('foo') is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.SetMetadataItem('foo', 'bar')
    if ds.GetMetadataItem('foo') != 'bar':
        gdaltest.post_reason('fail')
        return 'fail'

    ds.SetMetadata({'foo': 'baz'}, 'OTHER')
    if ds.GetMetadataItem('foo', 'OTHER') != 'baz':
        gdaltest.post_reason('fail')
        return 'fail'

    # Band metadata: named domain first, then the default domain.
    ds.GetRasterBand(1).SetMetadata({'foo': 'baw'}, 'OTHER')
    if ds.GetRasterBand(1).GetMetadataItem('foo', 'OTHER') != 'baw':
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetMetadataItem('INTERLEAVE', 'IMAGE_STRUCTURE') != 'BAND':
        gdaltest.post_reason('fail')
        return 'fail'

    if len(ds.GetRasterBand(1).GetMetadata()) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetMetadataItem('foo') is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetMetadataItem('foo', 'baz')
    if ds.GetRasterBand(1).GetMetadataItem('foo') != 'baz':
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetMetadata({'foo': 'baw'})
    if ds.GetRasterBand(1).GetMetadataItem('foo') != 'baw':
        gdaltest.post_reason('fail')
        return 'fail'

    # Color interpretation and color table: set, read back, clear.
    if ds.GetRasterBand(1).GetColorInterpretation() != gdal.GCI_GrayIndex:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetColorInterpretation(gdal.GCI_Undefined)

    ct = ds.GetRasterBand(1).GetColorTable()
    if ct is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    ct = gdal.ColorTable()
    ct.SetColorEntry(0, (1, 2, 3))
    if ds.GetRasterBand(1).SetColorTable(ct) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    # The 3-component entry set above is expected back with alpha 255.
    ct = ds.GetRasterBand(1).GetColorTable()
    if ct is None:
        gdaltest.post_reason('fail')
        return 'fail'
    if ct.GetColorEntry(0) != (1, 2, 3, 255):
        gdaltest.post_reason('fail')
        return 'fail'

    # Fetching the table a second time must still succeed.
    ct = ds.GetRasterBand(1).GetColorTable()
    if ct is None:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).SetColorTable(None) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    ct = ds.GetRasterBand(1).GetColorTable()
    if ct is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    # Raster attribute table: none initially, set an empty one, clear it.
    rat = ds.GetRasterBand(1).GetDefaultRAT()
    if rat is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).SetDefaultRAT(None) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    ref_rat = gdal.RasterAttributeTable()
    if ds.GetRasterBand(1).SetDefaultRAT(ref_rat) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    rat = ds.GetRasterBand(1).GetDefaultRAT()
    if rat is None:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).SetDefaultRAT(None) != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    rat = ds.GetRasterBand(1).GetDefaultRAT()
    if rat is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    # Statistics: nothing before computation; GetStatistics(0, 0) must not
    # report a non-negative value at index 3, GetStatistics(1, 1) yields the
    # real minimum (74).
    if ds.GetRasterBand(1).GetMinimum() is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    got_stats = ds.GetRasterBand(1).GetStatistics(0, 0)
    if got_stats[3] >= 0.0:
        gdaltest.post_reason('fail')
        return 'fail'

    got_stats = ds.GetRasterBand(1).GetStatistics(1, 1)
    if got_stats[0] != 74.0:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetMinimum() != 74.0:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetMaximum() != 255.0:
        gdaltest.post_reason('fail')
        return 'fail'

    # Explicitly set statistics are returned as-is until recomputed.
    ds.GetRasterBand(1).SetStatistics(1, 2, 3, 4)
    got_stats = ds.GetRasterBand(1).GetStatistics(1, 1)
    if got_stats != [1, 2, 3, 4]:
        print(got_stats)
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).ComputeStatistics(0)
    got_stats = ds.GetRasterBand(1).GetStatistics(1, 1)
    if got_stats[0] != 74.0:
        gdaltest.post_reason('fail')
        return 'fail'

    minmax = ds.GetRasterBand(1).ComputeRasterMinMax()
    if minmax != (74.0, 255.0):
        gdaltest.post_reason('fail')
        print(minmax)
        return 'fail'

    # Offset and scale: defaults first, then set and read back.
    if ds.GetRasterBand(1).GetOffset() != 0.0:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetScale() != 1.0:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetOffset(10.0)
    if ds.GetRasterBand(1).GetOffset() != 10.0:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetScale(2.0)
    if ds.GetRasterBand(1).GetScale() != 2.0:
        gdaltest.post_reason('fail')
        return 'fail'

    # Overviews: build one level; an invalid index must yield None.
    ds.BuildOverviews('NEAR', [2])
    if ds.GetRasterBand(1).GetOverviewCount() != 1:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetOverview(-1) is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetOverview(0) is None:
        gdaltest.post_reason('fail')
        return 'fail'

    # NOTE(review): this repeats the check immediately above verbatim.
    if ds.GetRasterBand(1).GetOverview(0) is None:
        gdaltest.post_reason('fail')
        return 'fail'

    # Histograms: computed histogram must match the default one, then a
    # custom default histogram is set and read back.
    got_hist = ds.GetRasterBand(1).GetHistogram()
    if len(got_hist) != 256:
        gdaltest.post_reason('fail')
        return 'fail'

    (minval, maxval, nitems,
     got_hist2) = ds.GetRasterBand(1).GetDefaultHistogram()
    if minval != -0.5:
        gdaltest.post_reason('fail')
        return 'fail'
    if maxval != 255.5:
        gdaltest.post_reason('fail')
        return 'fail'
    if nitems != 256:
        gdaltest.post_reason('fail')
        return 'fail'
    if got_hist != got_hist2:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetDefaultHistogram(1, 2, [3])
    (minval, maxval, nitems,
     got_hist3) = ds.GetRasterBand(1).GetDefaultHistogram()
    if minval != 1:
        gdaltest.post_reason('fail')
        return 'fail'
    if maxval != 2:
        gdaltest.post_reason('fail')
        return 'fail'
    if nitems != 1:
        gdaltest.post_reason('fail')
        return 'fail'
    if got_hist3[0] != 3:
        gdaltest.post_reason('fail')
        return 'fail'

    # Nodata value: unset initially, then set and read back.
    got_nodatavalue = ds.GetRasterBand(1).GetNoDataValue()
    if got_nodatavalue is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetNoDataValue(123)
    got_nodatavalue = ds.GetRasterBand(1).GetNoDataValue()
    if got_nodatavalue != 123:
        gdaltest.post_reason('fail')
        return 'fail'

    # Mask bands: flags are expected to be 8 once nodata is set, and 2
    # after a dataset-level mask band has been created.
    if ds.GetRasterBand(1).GetMaskFlags() != 8:
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetMaskBand() is None:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.CreateMaskBand(0)

    if ds.GetRasterBand(1).GetMaskFlags() != 2:
        print(ds.GetRasterBand(1).GetMaskFlags())
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetMaskBand() is None:
        gdaltest.post_reason('fail')
        return 'fail'

    # Band-level mask creation is only called; its result is not checked.
    ds.GetRasterBand(1).CreateMaskBand(0)

    if ds.GetRasterBand(1).HasArbitraryOverviews() != 0:
        gdaltest.post_reason('fail')
        return 'fail'

    # Unit type, category names and description.
    ds.GetRasterBand(1).SetUnitType('foo')
    if ds.GetRasterBand(1).GetUnitType() != 'foo':
        gdaltest.post_reason('fail')
        return 'fail'

    if ds.GetRasterBand(1).GetCategoryNames() is not None:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetCategoryNames(['foo'])
    if ds.GetRasterBand(1).GetCategoryNames() != ['foo']:
        gdaltest.post_reason('fail')
        return 'fail'

    ds.GetRasterBand(1).SetDescription('bar')

    # Close the dataset, then remove the temporary file.
    ds = None

    gdal.GetDriverByName('GTiff').Delete('tmp/byte.tif')

    return 'success'
def PiNo_classifier(infile, outfile, sat, detect_clouds, sun_azi,
                    fnorm_delta_val):
    """Classify a raster with ``classify_tile`` and write a palettized GeoTIFF.

    The input is classified in one piece or, when the estimated memory need
    exceeds the free memory reported by ``UnixMemory()``, in a
    ``ratio`` x ``ratio`` grid of tiles.  The result is written to
    ``outfile`` as a single-band, LZW-compressed Byte GeoTIFF carrying a fixed
    43-entry color table and the geotransform/projection of the input.

    Progress and errors are reported with ``print`` (several messages end in
    ``<br>``); exceptions are caught and printed, never propagated.

    :param infile: path of the raster to classify (opened with ``gdal.Open``).
    :param outfile: path of the GeoTIFF to create.
    :param sat: sensor identifier forwarded to ``classify_tile``; for
        ``"S2A_L1C"`` the input must have exactly 13 bands.
    :param detect_clouds: ``0`` recodes classes 1 and 2 to 34; any other value
        prepares the cloud/shadow masks instead.  The morphological closing
        that would apply those masks is currently disabled, so in that case
        the classes are written exactly as ``classify_tile`` returned them.
    :param sun_azi: sun azimuth (anything ``int()`` accepts); the shadow mask
        is only prepared when it is greater than 0.
    :param fnorm_delta_val: forwarded unchanged to ``classify_tile``.
    :return: ``'Complete<br>'`` on success, ``-1`` when an ``S2A_L1C`` input
        does not have 13 bands, ``None`` when an exception was caught.
    """
    try:
        print("INTO pino classifier")
        start = time.time()
        print(infile)
        src_ds = gdal.Open(infile)
        if src_ds is None:
            # Without GDAL exceptions enabled gdal.Open signals failure by
            # returning None; fail with a readable message instead of an
            # AttributeError on the next line.
            raise RuntimeError("Unable to open input image: " + str(infile))
        num_bands = src_ds.RasterCount

        print(infile)
        print("Bands:")
        print(num_bands)

        if sat == "S2A_L1C" and num_bands != 13:
            src_ds = None
            print("Input image does not have 13 bands<br>")
            return -1

        DataType = gdal.GetDataTypeName(src_ds.GetRasterBand(1).DataType)
        print("Type: " + DataType)

        # Memory estimate: pixel count times a per-type size (8/16/32), then
        # the fixed factors of the original formula.
        # NOTE(review): the meaning of "/ 1024 * 6 * 20" and the unit of
        # UnixMemory()['free'] are not visible here -- confirm they match.
        known_sizes = {"Byte": 8, "UInt16": 16}
        sample_size = known_sizes.get(DataType, 32)
        # "//" keeps the integer arithmetic of the Python 2 original.
        TOTmem = (src_ds.RasterXSize * src_ds.RasterYSize * sample_size
                  // 1024 * 6 * 20)
        if DataType not in known_sizes:
            # The estimate has only ever been echoed for the fallback size.
            print(TOTmem)

        res = UnixMemory()
        FREEmem = res['free']
        print(FREEmem)

        # Tile when the estimate plus a 20% margin does not fit in free memory.
        if FREEmem < TOTmem + (TOTmem // 100 * 20):
            print("Processing using tiling ..... <br>")
            # Memory_Buffer is a module-level setting defined outside this
            # function.  Clamp to 1 so a ratio that rounds down to 0 cannot
            # cause a division by zero when the tile size is computed.
            ratio = max(1, int(round(TOTmem * Memory_Buffer / FREEmem * 1.)))
        else:
            ratio = 1

        print("Ratio: " + str(ratio) + "<br>")

        MaxX = src_ds.RasterXSize
        MaxY = src_ds.RasterYSize
        LenX = MaxX // ratio
        LenY = MaxY // ratio

        # Upper pixel bound of every tile column / row; the last tile absorbs
        # whatever remainder the integer division left over.
        Xval = [LenX * i for i in range(1, ratio)] + [MaxX]
        Yval = [LenY * i for i in range(1, ratio)] + [MaxY]

        driver = gdal.GetDriverByName("GTiff")
        print(outfile)

        dst_ds = driver.Create(outfile,
                               src_ds.RasterXSize,
                               src_ds.RasterYSize,
                               1,
                               gdal.GDT_Byte,
                               options=['COMPRESS=LZW'])
        dst_ds.GetRasterBand(1).SetRasterColorInterpretation(
            gdal.GCI_PaletteIndex)

        # (class id, RGB) pairs of the output palette.
        ctable = [
            (0, (0, 0, 0)), (1, (255, 255, 255)), (2, (192, 242, 255)),
            (3, (1, 255, 255)), (4, (0, 0, 0)), (5, (1, 1, 255)),
            (6, (1, 123, 255)), (7, (110, 150, 255)), (8, (168, 180, 255)),
            (9, (160, 255, 90)), (10, (1, 80, 1)), (11, (12, 113, 1)),
            (12, (1, 155, 1)),
            (13, (100, 190, 90)), (14, (146, 255, 165)), (15, (0, 0, 0)),
            (16, (210, 255, 153)), (17, (0, 0, 0)), (18, (0, 0, 0)),
            (19, (0, 0, 0)), (20, (0, 0, 0)), (21, (237, 255, 193)),
            (22, (200, 230, 200)), (23, (0, 0, 0)), (24, (0, 0, 0)),
            (25, (0, 0, 0)), (26, (0, 0, 0)), (27, (0, 0, 0)),
            (28, (0, 0, 0)), (29, (0, 0, 0)), (30, (200, 200, 150)),
            (31, (227, 225, 170)), (32, (0, 0, 0)), (33, (0, 0, 0)),
            (34, (255, 225, 255)), (35, (140, 5, 190)),
            (36, (255, 1, 1)), (37, (0, 0, 0)), (38, (0, 0, 0)),
            (39, (0, 0, 0)), (40, (20, 40, 10)), (41, (145, 1, 110)),
            (42, (100, 100, 100)),
        ]
        c = gdal.ColorTable()
        for cid, rgb in ctable:
            c.SetColorEntry(cid, rgb)

        dst_ds.GetRasterBand(1).SetColorTable(c)
        dst_ds.SetGeoTransform(src_ds.GetGeoTransform())
        dst_ds.SetProjection(src_ds.GetProjectionRef())

        print('Setting out file ....')
        # Read the freshly created band back to get a full-size working array.
        OUTCLASS = dst_ds.GetRasterBand(1).ReadAsArray(0, 0,
                                                       dst_ds.RasterXSize,
                                                       dst_ds.RasterYSize)
        print('Processing <br>')

        # Classify tile by tile: columns in the outer loop, rows in the inner.
        MinX = 0
        for tile_max_x in Xval:
            MinY = 0
            for tile_max_y in Yval:
                OUTCLASS[MinY:tile_max_y, MinX:tile_max_x] = classify_tile(
                    src_ds, MinX, MinY, tile_max_x, tile_max_y, sat,
                    fnorm_delta_val, DataType)
                MinY = tile_max_y
            MinX = tile_max_x

        # Persist the raw classification now so it survives a failure in the
        # post-processing below.
        dst_ds.GetRasterBand(1).WriteArray(OUTCLASS)

        if detect_clouds == 0:
            OUTCLASS[OUTCLASS == 1] = 34
            OUTCLASS[OUTCLASS == 2] = 34
        else:
            try:
                # 17x17 diamond-like kernel for the morphological closing.
                # Each pair is the [start, stop) column span set in that row.
                # NOTE(review): the closing calls (mmorph.close) are disabled,
                # so the kernel and both masks below are computed but never
                # applied to OUTCLASS.
                kernel_spans = [(8, 9), (7, 9), (6, 9), (5, 10), (4, 11),
                                (3, 12), (2, 14), (2, 14), (0, 16), (2, 14),
                                (2, 14), (3, 12), (4, 11), (5, 10), (6, 9),
                                (7, 9), (8, 9)]
                kernel = numpy.zeros((17, 17), dtype=bool)
                for row, (first, stop) in enumerate(kernel_spans):
                    kernel[row, first:stop] = True

                # Pixels classified as 1 or 2.
                BOOL_MATRIX = (OUTCLASS == 1) | (OUTCLASS == 2)

                # ---- 3D filter for cloud / shadow detection ----
                if int(sun_azi) > 0:
                    sft = getend([100, 100], 270 - int(sun_azi), 20)
                    shiftX = 100 - int(sft[0])
                    shiftY = 100 - int(sft[1])

                    # NOTE(review): shiftX is rolled along axis 0 (rows) and
                    # shiftY along axis 1 (columns) -- confirm that this is
                    # the intended orientation.
                    BOOL_MATRIX = numpy.roll(BOOL_MATRIX, shiftX, axis=0)
                    BOOL_MATRIX = numpy.roll(BOOL_MATRIX, shiftY, axis=1)
                    print("Cloud masking step 1<br>")

                    # Candidate classes, restricted to the shifted mask.
                    SHDW_MATRIX = numpy.zeros(OUTCLASS.shape, dtype=bool)
                    for class_id in (10, 40, 41, 42, 34, 35, 36, 4, 5, 6, 7,
                                     8):
                        SHDW_MATRIX |= (OUTCLASS == class_id)
                    SHDW_MATRIX &= BOOL_MATRIX
                    print("Cloud masking step 2<br>")
                    print("Cloud masking step 3<br>")
            except Exception:
                # Best effort: keep the classification already written above.
                print("Image is TOO BIG for morphological filters <br>")

        dst_ds.GetRasterBand(1).WriteArray(OUTCLASS)

        # Dropping the references closes the datasets and flushes the output.
        dst_ds = None
        src_ds = None

        print("Execution time: " + str(time.time() - start) + "<br>")

        return 'Complete<br>'

    except Exception as e:
        print("Error  :")
        print(str(e))
Ejemplo n.º 24
0
# The palette is computed from bands 1-3 (red, green, blue), so the source
# must provide at least three bands.
if src_ds.RasterCount < 3:
    # The message has two placeholders and therefore needs a two-item tuple;
    # the dataset description is normally the name it was opened with.
    print('%s has %d bands, need 3 for inputs red, green and blue.'
          % (src_ds.GetDescription(), src_ds.RasterCount))
    sys.exit(1)

# Ensure we recognise the driver.

dst_driver = gdal.GetDriverByName(format)
if dst_driver is None:
    print('"%s" driver not registered.' % format)
    sys.exit(1)

# Generate the median cut PCT: fills `ct` with `color_count` entries derived
# from the three input bands, reporting progress on the terminal.

ct = gdal.ColorTable()

err = gdal.ComputeMedianCutPCT(src_ds.GetRasterBand(1),
                               src_ds.GetRasterBand(2),
                               src_ds.GetRasterBand(3),
                               color_count,
                               ct,
                               callback=gdal.TermProgress,
                               callback_data='Generate PCT')

# Create the working file.  We have to use TIFF since there are few formats
# that allow setting the color table after creation.

if format == 'GTiff':
    tif_filename = dst_filename
else:
import gdal, glob

# Attach the same 7-class palette to band 1 of every classified GeoTIFF
# matching the pattern below
# (e.g. Classified_Output_April_to_June_2016_to_2020.tif).
listFiles = glob.glob(
	'/Users/Andy/Documents/Rwanda/Data/TropWet_Outputs/TW_v7.2/Classified*')

# RGB color of each pixel value; the position in the list is the pixel value.
palette = [
	(209, 209, 204),
	(69, 174, 144),
	(221, 154, 162),
	(229, 206, 113),
	(34, 80, 202),
	(50, 133, 54),
	(203, 191, 124),
]

for fn in listFiles:
	# Access mode 1 is GDAL's update flag (gdal.GA_Update), so the palette is
	# stored in the file itself.
	ds = gdal.Open(fn, 1)
	band = ds.GetRasterBand(1)

	# Build a fresh color table for this file from the palette.
	colors = gdal.ColorTable()
	for value, rgb in enumerate(palette):
		colors.SetColorEntry(value, rgb)

	# Attach the table, then mark the band as palette-indexed.
	band.SetRasterColorTable(colors)
	band.SetRasterColorInterpretation(gdal.GCI_PaletteIndex)

	# Releasing both references closes the dataset and saves the file.
	del band, ds