def annotate_ecoregions(ecoregion_uri, table_uri):
    """Copy the ecoregion shapefile and add the magnitude/area fields from
    the CSV at table_uri, which is keyed by its 'region' column."""
    lookup_table = raster_utils.get_lookup_from_csv(table_uri, 'region')
    
    print 'generating report'
    esri_driver = ogr.GetDriverByName('ESRI Shapefile')

    original_datasource = ogr.Open(ecoregion_uri)
    updated_datasource_uri = os.path.join(os.path.dirname(ecoregion_uri), 'annotated_ecoregions_2.shp')
    #If a shapefile with the same name and path already exists, delete it,
    #then copy the input shapefile into the designated output folder
    if os.path.isfile(updated_datasource_uri):
        os.remove(updated_datasource_uri)
    datasource_copy = esri_driver.CopyDataSource(original_datasource, updated_datasource_uri)
    layer = datasource_copy.GetLayer()

    #(CSV column name, shapefile field name) pairs; DBF field names are
    #limited to 10 characters
    new_field_names = [
        ('Magnitude', 'magnitude'), ('A(80)', 'A80'), ('A(90)', 'A90'),
        ('A(95)', 'A95')]

    for table_field, field_name in new_field_names:
        field_def = ogr.FieldDefn(field_name, ogr.OFTReal)
        layer.CreateField(field_def)

    for feature_id in xrange(layer.GetFeatureCount()):
        feature = layer.GetFeature(feature_id)

        feature_eco_name = raster_utils._smart_cast(feature.GetField('ECO_NAME'))
        
        for table_field, field_name in new_field_names:
            try:
                value = lookup_table[feature_eco_name][table_field]
                if feature_eco_name == 'Balsas Dry Forests':
                    #debug output for a region known to have table problems
                    print feature_eco_name, table_field, value
                feature.SetField(field_name, float(value))
            except (KeyError, TypeError) as e:
                #missing or uncastable table entries get a -1.0 sentinel
                if feature_eco_name == 'Balsas Dry Forests':
                    print e
                feature.SetField(field_name, -1.0)

        #Save back to datasource
        layer.SetFeature(feature)
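
def _parse_grid_id(grid_id):
    """Return (row, col) integer indices for a 'row-col' grid ID.

    A shared helper for the grid-ID parsing used below.  Spreadsheet
    software apparently mangles some IDs like '1-5' into 'Jan-5', so month
    abbreviations are translated back to row numbers.  Raises KeyError when
    the row token is neither an integer nor a known month; callers either
    propagate that or skip the ID."""
    month_to_number = {
        'Jan': 1, 'Feb': 2, 'Mar': 3, 'Apr': 4, 'May': 5, 'Jun': 6,
        'Jul': 7, 'Aug': 8, 'Sep': 9, 'Oct': 10, 'Nov': 11, 'Dec': 12,
    }
    row_id, col_id = grid_id.split('-')
    try:
        return int(row_id), int(col_id)
    except ValueError:
        return month_to_number[row_id], int(col_id)
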
def average_layers():
    """Average a set of global rasters within each 100 km grid cell and
    append the averages as new columns on the grid results table."""
    base_table_uri = "C:/Users/rich/Desktop/all_grid_results_100km_clean_v2.csv"

    #the average layers need to be masked to the biomass regions;
    #giant_layer.tif below is the union of the three biomass rasters

    giant_layer_uri = "C:/Users/rich/Desktop/average_layers_projected/giant_layer.tif"

    af_uri = "C:/Users/rich/Desktop/af_biov2ct1.tif"
    am_uri = "C:/Users/rich/Desktop/am_biov2ct1.tif"
    as_uri = "C:/Users/rich/Desktop/as_biov2ct1.tif"
    cell_size = raster_utils.get_cell_size_from_uri(am_uri)
    #this call built giant_layer.tif by summing the three biomass rasters;
    #left commented out, presumably because the output already exists
    #raster_utils.vectorize_datasets(
    #    [af_uri, am_uri, as_uri], lambda x, y, z: x + y + z, giant_layer_uri,
    #    gdal.GDT_Float32, -1, cell_size, 'union', vectorize_op=False)

    table_file = open(base_table_uri, 'rU')
    table_header = table_file.readline().rstrip()

    out_table_uri = "C:/Users/rich/Desktop/all_grid_results_100km_human_elevation.csv"
    out_table_file = codecs.open(out_table_uri, 'w', 'utf-8')

    average_raster_list = [
        ("C:/Users/rich/Desktop/average_layers_projected/lighted_area_luminosity.tif", 'Lighted area density'),
        ("C:/Users/rich/Desktop/average_layers_projected/fi_average.tif", 'Fire densities'),
        ("C:/Users/rich/Desktop/average_layers_projected/glbctd1t0503m.tif", 'FAO_Cattle'),
        ("C:/Users/rich/Desktop/average_layers_projected/glbgtd1t0503m.tif", 'FAO_Goat'),
        ("C:/Users/rich/Desktop/average_layers_projected/glbpgd1t0503m.tif", 'FAO_Pig'),
        ("C:/Users/rich/Desktop/average_layers_projected/glbshd1t0503m.tif", 'FAO_Sheep'),
        ("C:/Users/rich/Desktop/average_layers_projected/glds00ag.tif", 'Human population density AG'),
        ("C:/Users/rich/Desktop/average_layers_projected/glds00g.tif", 'Human population density G'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_11.tif', '"11: Urban, Dense settlements"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_12.tif', '"12: Dense settlements, Dense settlements"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_22.tif', '"22: Irrigated villages, Villages"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_23.tif', '"23: Cropped & pastoral villages, Villages"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_24.tif', '"24: Pastoral villages, Villages"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_25.tif', '"25: Rainfed villages, Villages"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_26.tif', '"26: Rainfed mosaic villages, Villages"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_31.tif', '"31: Residential irrigated cropland, Croplands"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_32.tif', '"32: Residential rainfed mosaic, Croplands"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_33.tif', '"33: Populated irrigated cropland, Croplands"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_34.tif', '"34: Populated rainfed cropland, Croplands"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_35.tif', '"35: Remote croplands, Croplands"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_41.tif', '"41: Residential rangelands, Rangelands"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_42.tif', '"42: Populated rangelands, Rangelands"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_43.tif', '"43: Remote rangelands, Rangelands"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_51.tif', '"51: Populated forests, Forested"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_52.tif', '"52: Remote forests, Forested"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_61.tif', '"61: Wild forests, Wildlands"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_62.tif', '"62: Sparse trees, Wildlands"'),
        ('C:/Users/rich/Desktop/average_layers_projected/anthrome_63.tif', '"63: Barren, Wildlands"'),
        ("C:/Users/rich/Desktop/average_layers_projected/5km_global_pantropic_dem.tif", '"Average Elevation"'),
        ]

    clipped_raster_list = []

    for average_raster_uri, header in average_raster_list:
        print 'clipping ' + average_raster_uri
        clipped_raster_uri = os.path.join(
            os.path.dirname(average_raster_uri), 'temp',
            os.path.basename(average_raster_uri))
        cell_size = raster_utils.get_cell_size_from_uri(average_raster_uri)
        #intersecting with giant_layer clips each raster to the biomass extent
        raster_utils.vectorize_datasets(
            [average_raster_uri, giant_layer_uri], lambda x, y: x,
            clipped_raster_uri, gdal.GDT_Float32, -1, cell_size,
            'intersection', vectorize_op=False)
        clipped_raster_list.append((clipped_raster_uri, header))

    dataset_list = [gdal.Open(uri) for uri, label in clipped_raster_list]
    band_list = [ds.GetRasterBand(1) for ds in dataset_list]
    nodata_list = [band.GetNoDataValue() for band in band_list]

    extended_table_headers = ','.join([header for _, header in average_raster_list])


    def write_to_file(value):
        """Write to the UTF-8 output; byte strings that fail the implicit
        ASCII decode are decoded as latin-1 first."""
        try:
            out_table_file.write(value)
        except UnicodeDecodeError:
            out_table_file.write(value.decode('latin-1'))

    write_to_file(table_header + ',' + extended_table_headers + '\n')

    for line in table_file:
        split_line = line.rstrip().split(',')
        grid_id = split_line[2]
        grid_row_index, grid_col_index = _parse_grid_id(grid_id)
            
        print 'processing grid id ' + grid_id

        ds = dataset_list[0]
        base_srs = osr.SpatialReference(ds.GetProjection())
        lat_lng_srs = base_srs.CloneGeogCS()
        coord_transform = osr.CoordinateTransformation(
            base_srs, lat_lng_srs)
        grid_resolution = 100  #km
        
        row_coord = grid_row_index * grid_resolution * 1000 + GLOBAL_UPPER_LEFT_ROW
        col_coord = grid_col_index * grid_resolution * 1000 + GLOBAL_UPPER_LEFT_COL

        lng_coord, lat_coord, _ = coord_transform.TransformPoint(
            col_coord, row_coord)
        #re-emit the row with a normalized grid ID; columns 11-12 (0-based)
        #are replaced by the grid cell's lat/lng
        write_to_file(
            ','.join(split_line[0:2]) +
            ',%d-%d,' % (grid_row_index, grid_col_index) +
            ','.join(split_line[3:11]) +
            ',%f,%f,' % (lat_coord, lng_coord) +
            ','.join(split_line[13:]))

        for (_, header), band, ds, nodata in zip(clipped_raster_list, band_list, dataset_list, nodata_list):

            gt = ds.GetGeoTransform()
            n_rows = ds.RasterYSize
            n_cols = ds.RasterXSize
               
            xoff = int(grid_col_index * (grid_resolution * 1000.0) / gt[1])
            yoff = int(grid_row_index * (grid_resolution * 1000.0) / (-gt[5]))
            win_xsize = int((grid_resolution * 1000.0) / gt[1])
            win_ysize = int((grid_resolution * 1000.0) / (-gt[5]))

            if xoff + win_xsize > n_cols:
                win_xsize = n_cols - xoff
            if yoff + win_ysize > n_rows:
                win_ysize = n_rows - yoff

            block = band.ReadAsArray(
                xoff=xoff, yoff=yoff, win_xsize=win_xsize, win_ysize=win_ysize)
            valid_block = block[block != nodata]
            if valid_block.size > 0:
                block_average = numpy.average(valid_block)
            else:
                #match the -9999 sentinel used elsewhere for empty windows
                block_average = -9999.0
            write_to_file(',%f' % block_average)
        write_to_file('\n')
def _make_magnitude_maps(base_uri, table_uri):
    """Write one 100 km resolution GTiff per table column (Magnitude,
    A80/A90/A95, Average Elevation), with one pixel per grid cell."""
    grid_resolution = 100  #km
    base_ds = gdal.Open(base_uri)
    
    projection = base_ds.GetProjection()
    driver = gdal.GetDriverByName('GTiff')

    gt = base_ds.GetGeoTransform()
    #gt = (
    #    GLOBAL_UPPER_LEFT_COL, gt[1], gt[2],
    #    GLOBAL_UPPER_LEFT_ROW, gt[4], gt[5]
    #)

    n_rows = base_ds.RasterYSize
    n_cols = base_ds.RasterXSize

    lookup_table = raster_utils.get_lookup_from_csv(table_uri, 'ID100km')
    
    #the output grid dimensions depend only on the lookup table, so scan the
    #grid IDs once, before the per-map-type loop, for the maximum indices
    n_rows_grid = -1
    n_cols_grid = -1
    for grid_id in lookup_table:
        try:
            grid_row_index, grid_col_index = _parse_grid_id(grid_id)
        except KeyError:
            continue
        n_rows_grid = max(n_rows_grid, grid_row_index)
        n_cols_grid = max(n_cols_grid, grid_col_index)
    n_rows_grid += 1
    n_cols_grid += 1

    output_dir = os.path.dirname(base_uri)
    new_geotransform = (
        gt[0], grid_resolution * 1000.0, gt[2],
        gt[3], gt[4], -grid_resolution * 1000.0)

    for map_type in ['Magnitude', 'A80', 'A90', 'A95', 'Average Elevation']:
        output_uri = os.path.join(output_dir, map_type + '.tif')

        output_ds = driver.Create(
            output_uri.encode('utf-8'), n_cols_grid, n_rows_grid, 1, gdal.GDT_Float32)
        output_ds.SetProjection(projection)
        output_ds.SetGeoTransform(new_geotransform)
        output_band = output_ds.GetRasterBand(1)

        output_nodata = -9999
        output_band.SetNoDataValue(output_nodata)
        output_band.Fill(output_nodata)

        last_time = time.time()
        for grid_id in lookup_table:
            current_time = time.time()
            if current_time - last_time > 5.0:
                print "%s working..." % (map_type,)
                last_time = current_time


            try:
                grid_row_index, grid_col_index = _parse_grid_id(grid_id)
            except KeyError:
                continue

            try:
                output_band.WriteArray(
                    numpy.array([[float(lookup_table[grid_id][map_type])]]),
                    xoff=grid_col_index, yoff=n_rows_grid - grid_row_index - 1)
            except ValueError:
                #leave the pixel at nodata when the table value isn't numeric
                pass
    def run(self):
        """Aggregate the biomass and covariate layers into per-grid-cell
        statistics at each resolution in GRID_RESOLUTION_LIST, writing one
        CSV per resolution."""
        biomass_ds = gdal.Open(GLOBAL_BIOMASS_URI, gdal.GA_ReadOnly)
        n_rows, n_cols = raster_utils.get_row_col_from_uri(GLOBAL_BIOMASS_URI)

        base_srs = osr.SpatialReference(biomass_ds.GetProjection())
        lat_lng_srs = base_srs.CloneGeogCS()
        coord_transform = osr.CoordinateTransformation(
            base_srs, lat_lng_srs)
        geo_trans = biomass_ds.GetGeoTransform()
        biomass_band = biomass_ds.GetRasterBand(1)
        biomass_nodata = biomass_band.GetNoDataValue()

        forest_table = raster_utils.get_lookup_from_csv(
            self.forest_only_table_uri, 'gridID')
        forest_headers = list(forest_table.values()[0].keys())

        nonexistent_files = [
            uri for uri in ALIGNED_LAYERS_TO_AVERAGE + ALIGNED_LAYERS_TO_MAX
            if not os.path.isfile(uri)]
        if nonexistent_files:
            raise Exception(
                "The following files don't exist: %s" %
                str(nonexistent_files))

        average_dataset_list = [
            gdal.Open(uri) for uri in ALIGNED_LAYERS_TO_AVERAGE]

        average_band_list = [ds.GetRasterBand(1) for ds in average_dataset_list]
        average_nodata_list = [
            band.GetNoDataValue() for band in average_band_list]

        max_dataset_list = [gdal.Open(uri) for uri in ALIGNED_LAYERS_TO_MAX]
        max_band_list = [ds.GetRasterBand(1) for ds in max_dataset_list]
        max_nodata_list = [band.GetNoDataValue() for band in max_band_list]

        for global_grid_resolution, grid_output_filename in \
                zip(GRID_RESOLUTION_LIST, self.grid_output_file_list):
            try:
                grid_output_file = open(grid_output_filename, 'w')
                grid_output_file.write('grid id,lat_coord,lng_coord')
                for filename in (
                        ALIGNED_LAYERS_TO_AVERAGE + ALIGNED_LAYERS_TO_MAX):
                    grid_output_file.write(
                        ',%s' % os.path.splitext(
                            os.path.basename(filename))[0][len('aligned_'):])
                for header in forest_headers:
                    grid_output_file.write(',%s' % header)
                grid_output_file.write('\n')

                n_grid_rows = int(
                    (-geo_trans[5] * n_rows) / (global_grid_resolution * 1000))
                n_grid_cols = int(
                    (geo_trans[1] * n_cols) / (global_grid_resolution * 1000))

                grid_row_stepsize = int(n_rows / float(n_grid_rows))
                grid_col_stepsize = int(n_cols / float(n_grid_cols))

                for grid_row in xrange(n_grid_rows):
                    for grid_col in xrange(n_grid_cols):
                        #first check to make sure there is biomass at all!
                        global_row = grid_row * grid_row_stepsize
                        global_col = grid_col * grid_col_stepsize
                        global_col_size = min(
                            grid_col_stepsize, n_cols - global_col)
                        global_row_size = min(
                            grid_row_stepsize, n_rows - global_row)
                        array = biomass_band.ReadAsArray(
                            global_col, global_row, global_col_size,
                            global_row_size)
                        if numpy.count_nonzero(array != biomass_nodata) == 0:
                            continue

                        grid_id = '%d-%d' % (grid_row, grid_col)
                        grid_row_center = (
                            -(grid_row + 0.5) * (global_grid_resolution*1000) +
                            geo_trans[3])
                        grid_col_center = (
                            (grid_col + 0.5) * (global_grid_resolution*1000) +
                            geo_trans[0])
                        grid_lng_coord, grid_lat_coord, _ = (
                            coord_transform.TransformPoint(
                                grid_col_center, grid_row_center))
                        grid_output_file.write(
                            '%s,%s,%s' % (grid_id, grid_lat_coord,
                                          grid_lng_coord))


                        #take the average values
                        for band, nodata, layer_uri in zip(
                                average_band_list, average_nodata_list,
                                ALIGNED_LAYERS_TO_AVERAGE):
                            array = band.ReadAsArray(
                                global_col, global_row, global_col_size,
                                global_row_size)
                            layer_name = os.path.splitext(
                                os.path.basename(layer_uri))[0][len('aligned_'):]

                            #for layers not in this list, nodata pixels count
                            #as 0.0 so they contribute to the cell average
                            pure_average_layers = [
                                'global_elevation', 'global_water_capacity',
                                'fi_average', 'lighted_area_luminosity',
                                'glbctd1t0503m', 'glbgtd1t0503m',
                                'glbpgd1t0503m', 'glbshd1t0503m', 'glds00ag',
                                'glds00g']
                            if layer_name not in pure_average_layers:
                                array[array == nodata] = 0.0
                            valid_values = array[array != nodata]
                            if valid_values.size != 0:
                                value = numpy.average(valid_values)
                            else:
                                value = -9999.
                            grid_output_file.write(',%f' % value)

                        #take the modal (most common) values
                        for band, nodata in zip(max_band_list, max_nodata_list):
                            array = band.ReadAsArray(
                                global_col, global_row, global_col_size,
                                global_row_size)
                            #get the most common value
                            valid_values = array[array != nodata]
                            if valid_values.size != 0:
                                value = scipy.stats.mode(valid_values)[0][0]
                                grid_output_file.write(',%f' % value)
                            else:
                                grid_output_file.write(',-9999')

                        #add the forest_only values
                        for header in forest_headers:
                            try:
                                value = forest_table[grid_id][header]
                                if isinstance(value, unicode):
                                    grid_output_file.write(
                                        ',%s' % value.encode(
                                            'latin-1', 'replace'))
                                else:
                                    grid_output_file.write(',%s' % value)
                            except KeyError:
                                grid_output_file.write(',-9999')


                        grid_output_file.write('\n')
                grid_output_file.close()
            except IndexError:
                #close and remove the partial output before re-raising
                grid_output_file.close()
                os.remove(grid_output_filename)
                raise
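
#A minimal driver sketch (hypothetical paths; the real entry point is not
#shown in this excerpt):
#
#    annotate_ecoregions(
#        'C:/Users/rich/Desktop/ecoregions.shp',
#        'C:/Users/rich/Desktop/magnitude_table.csv')
#    average_layers()
#    _make_magnitude_maps(
#        'C:/Users/rich/Desktop/Magnitude.tif',
#        'C:/Users/rich/Desktop/all_grid_results_100km_clean_v2.csv')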