import subprocess
import sys

# Install geemap into the interpreter that is running this notebook; a bare
# "python" found on PATH can belong to a different environment than the kernel.
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'geemap'])

# Checks whether this notebook is running on Google Colab: there the
# folium-based geemap.eefolium front end is used, elsewhere plain geemap.
try:
    import google.colab  # noqa: F401  (imported only to detect Colab)
    import geemap.eefolium as emap
except ImportError:
    import geemap as emap

# Authenticates and initializes Earth Engine
import ee

try:
    ee.Initialize()
except Exception:
    # Initialization failed (typically: no stored credentials yet), so run
    # the interactive authentication flow once and retry.
    ee.Authenticate()
    ee.Initialize()

# %%
"""
## Create an interactive map 
The default basemap is `Google Satellite`. [Additional basemaps](https://github.com/giswqs/geemap/blob/master/geemap/geemap.py#L13) can be added using the `Map.add_basemap()` function. 
"""

# %%
Map = emap.Map(center=[40, -100], zoom=4)
Map.add_basemap('ROADMAP')  # Add Google Map
Map

# %%
"""
class irrigation30():
    '''Label cropland in a small area as irrigated or rainfed.

    Monthly Sentinel-2 NDVI for one year is clustered with k-means inside the
    GFSAD30 cropland mask. Each cluster's average NDVI curve is searched for
    crop-cycle peaks, and TerraClimate precipitation and temperature around
    each peak decide whether that cycle is labelled rainfed or irrigated.

    Typical use: construct with a location, call fit_predict(), then use
    plot_map(), plot_avg_ndvi(), write_image_asset() or
    write_image_google_drive().

    Earth Engine is initialized (and authenticated only if needed) per
    instance in __init__, not when the class is defined.
    '''

    # Set the max number of samples used in the clustering
    maxSample = 100000
    # Technically, resolution can be a parameter in __init__
    #     But we did not fully test resolutions different from 30 m.
    resolution = 30
    # Reference: https://hess.copernicus.org/articles/19/4441/2015/hessd-12-1329-2015.pdf
    # "If NDVI at peak is less than 0.40, the peak is not counted as cultivation."
    #     The article uses 10-day composite NDVI while we use monthly NDVI.
    #     To account for averaging effect, our threshold is slightly lower than 0.4.
    crop_ndvi_threashold = 0.3
    # Estimated based on http://www.fao.org/3/s2022e/s2022e07.htm#TopOfPage
    water_need_threshold = 100
    # Rename ndvi bands to the following
    ndvi_lst = ['ndvi' + str(i).zfill(2) for i in range(1, 13)]
    # Give descriptive name for the month
    month_lst = [
        'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct',
        'Nov', 'Dec'
    ]
    # List of colors used to plot each cluster (index = cluster id)
    cluster_color = [
        'red', 'blue', 'orange', 'yellow', 'darkgreen', 'lightgreen',
        'lightblue', 'purple', 'pink', 'lightgray'
    ]

    def __init__(self,
                 center_lat=43.771114,
                 center_lon=-116.736866,
                 edge_len=0.005,
                 year=2018,
                 maxClusters_set=2):
        '''
        Parameters:
            center_lat: latitude (degrees) of the centre of the area of interest
            center_lon: longitude (degrees) of the centre of the area of interest
            edge_len: edge length (degrees) of the square area of interest,
                from 0.005 to 0.5
            year: year the satellite data should pull images for, from 2017
                up to the previous calendar year
            maxClusters_set: number of k-means clusters, from 2 to 10

        Raises:
            ValueError: if a parameter has the wrong type or is out of range.
                All checks run before any Earth Engine call is made.
        '''
        # error handle parameter issues; ints are accepted for coordinates
        # and edge length and stored as floats
        if not self.__is_number(center_lat):
            raise ValueError('Please enter float value for latitude')
        if not self.__is_number(center_lon):
            raise ValueError('Please enter float value for longitude')
        if not (self.__is_number(edge_len) and 0.005 <= edge_len <= 0.5):
            raise ValueError(
                'Please enter float value for edge length between 0.5 and 0.005'
            )
        # (range is 2017 to year prior)
        last_year = int(time.strftime("%Y")) - 1
        if not (self.__is_integer(year) and 2017 <= year <= last_year):
            raise ValueError(
                'Please enter integer value for year >= 2017 and less than current year'
            )
        # n_clusters (2-10)
        if not (self.__is_integer(maxClusters_set)
                and 2 <= maxClusters_set <= 10):
            raise ValueError(
                'Please enter integer value for maxClusters_set between 2 and 10'
            )

        self.center_lat = float(center_lat)
        self.center_lon = float(center_lon)
        self.edge_len = float(edge_len)
        self.year = year
        self.maxClusters_set = maxClusters_set

        # Initialize the library; trigger the interactive authentication
        # flow only when initialization fails for lack of credentials.
        try:
            ee.Initialize()
        except ee.ee_exception.EEException:
            ee.Authenticate()
            ee.Initialize()

        # initialize remaining variables
        self.label = []
        self.comment = dict()
        self.avg_ndvi = np.zeros((2, 12))
        self.temperature_max = []
        self.temperature_min = []
        self.temperature_avg = []
        self.precipitation = []
        self.image = ee.Image()
        self.nClusters = 0
        self.simple_label = []

        # Create the bounding box using GEE API
        self.aoi_ee = self.__create_bounding_box_ee()
        # Estimate the area of interest
        self.dist_lon = self.__calc_distance(
            self.center_lon - self.edge_len / 2, self.center_lat,
            self.center_lon + self.edge_len / 2, self.center_lat)
        self.dist_lat = self.__calc_distance(
            self.center_lon, self.center_lat - self.edge_len / 2,
            self.center_lon, self.center_lat + self.edge_len / 2)
        print(
            'The selected area is approximately {:.2f} km by {:.2f} km'.format(
                self.dist_lon, self.dist_lat))

        # Estimate the amount of pixels used in the clustering algorithm
        est_total_pixels = round(self.dist_lat * self.dist_lon * (1000**2) /
                                 ((irrigation30.resolution)**2))
        self.nSample = min(irrigation30.maxSample, est_total_pixels)

        # hard-code a few things
        # base_asset_directory is where we are going to store output images
        self.base_asset_directory = "users/mbrimmer/w210_irrigated_croplands"
        self.model_projection = "EPSG:4326"
        self.testing_asset_folder = self.base_asset_directory + '/testing/'

    @staticmethod
    def __is_number(value):
        '''True for int or float values, excluding bool.'''
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    @staticmethod
    def __is_integer(value):
        '''True for int values, excluding bool.'''
        return isinstance(value, int) and not isinstance(value, bool)

    def __bounds(self):
        '''Return (west, south, east, north) of the square area of interest.'''
        half = self.edge_len / 2
        return (self.center_lon - half, self.center_lat - half,
                self.center_lon + half, self.center_lat + half)

    def __create_bounding_box_ee(self):
        '''Creates a rectangle for pulling image information using center coordinates and edge_len'''
        return ee.Geometry.Rectangle(list(self.__bounds()))

    def __create_bounding_box_shapely(self):
        '''Returns a box for coordinates to plug in as an image add-on layer'''
        return box(*self.__bounds())

    @staticmethod
    def __calc_distance(lon1, lat1, lon2, lat2):
        '''Return the haversine distance in km between two (lon, lat) points
        given in degrees.'''
        # Reference: https://stackoverflow.com/questions/19412462/getting-distance-between-two-points-based-on-latitude-longitude
        # approximate radius of earth in km
        R = 6373.0
        lon1, lat1, lon2, lat2 = (radians(v) for v in (lon1, lat1, lon2, lat2))
        dlon = lon2 - lon1
        dlat = lat2 - lat1
        a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
        c = 2 * atan2(sqrt(a), sqrt(1 - a))
        return R * c

    def __date_range(self):
        '''Return (start, end) date strings covering all of self.year.

        Earth Engine's filterDate treats the end date as exclusive, so the
        range ends on 1 January of the next year to keep 31 December.
        '''
        return '{}-01-01'.format(self.year), '{}-01-01'.format(self.year + 1)

    def __monthly_composite(self, image_collection, band):
        '''Collapse a single-band collection into one 12-band image.

        Band k (named band01 ... band12) is the per-pixel median of the images
        dated in calendar month k.
        '''
        months = ee.List.sequence(1, 12)
        by_month = ee.ImageCollection.fromImages(
            months.map(lambda m: image_collection.filter(
                ee.Filter.calendarRange(m, m, 'month')).median().set(
                    'month', m)).flatten())

        # Take all the bands that have been split into months as different
        # images in the collection (by_month), and merge into different bands
        def merge_bands(image, previous):
            '''Returns an image after merging the image with previous image'''
            return ee.Image(previous).addBands(image).copyProperties(
                image, image.propertyNames())

        merged = by_month.iterate(merge_bands, ee.Image())
        # Merging same-named bands yields band, band_1, ..., band_11; select
        # keeps only those twelve and renames them band01 ... band12.
        return ee.Image(merged).select(
            [band] + [band + '_' + str(i) for i in range(1, 12)],
            [band + str(i).zfill(2) for i in range(1, 13)])

    def __pull_Sentinel2_data(self):
        '''Output monthly Sentinel image dataset for a specified area with NDVI readings for the year
        merged with GFSAD30 and GFSAD1000 information, masked to GFSAD30 cropland'''
        band_red = "B4"  #10m
        band_nir = 'B8'  #10m
        start_date, end_date = self.__date_range()

        # Create image collection that contains the area of interest
        sentinel_ic = (ee.ImageCollection('COPERNICUS/S2')
                       .filterDate(start_date, end_date)
                       .filterBounds(self.aoi_ee)
                       .select(band_nir, band_red))

        # Get GFSAD30 image and clip to the area of interest
        gfsad30_img = (ee.ImageCollection("users/ajsohn/GFSAD30")
                       .filterBounds(self.aoi_ee)
                       .max()
                       .clip(self.aoi_ee))

        def calc_ndvi(img):
            '''A function to compute Normalized Difference Vegetation Index'''
            ndvi = ee.Image(
                img.normalizedDifference([band_nir, band_red])
                .rename(["ndvi"])
                .copyProperties(img, img.propertyNames()))
            return img.addBands(ndvi)

        # Apply the calculation of NDVI
        ndvi_ic = sentinel_ic.map(calc_ndvi).select('ndvi')

        # ---------- GET MONTHLY DATA ---------
        # 0 = water, 1 = non-cropland, 2 = cropland, 3 = 'no data'
        by_month_img = (self.__monthly_composite(ndvi_ic, 'ndvi')
                        .addBands(gfsad30_img.rename(['gfsad30']))
                        .addBands(ee.Image("USGS/GFSAD1000_V1").rename(['gfsad1000']))
                        .clip(self.aoi_ee))

        # Mask the cropland
        cropland = by_month_img.select('gfsad30').eq(2)
        return by_month_img.mask(cropland)

    def __pull_TerraClimate_data(self, band, multiplier=1):
        '''Return a (1, 12) array of the area-mean monthly TerraClimate values
        of `band` for self.year, scaled by `multiplier`.'''
        start_date, end_date = self.__date_range()

        # Create image collection that contains the area of interest
        terraclimate_ic = (ee.ImageCollection("IDAHO_EPSCOR/TERRACLIMATE")
                           .filterDate(start_date, end_date)
                           .filterBounds(self.aoi_ee)
                           .select(band))

        # Get TerraClimate monthly data
        by_month_img = self.__monthly_composite(terraclimate_ic,
                                                band).clip(self.aoi_ee)

        # Calculate the average value by month
        climate_dict = by_month_img.reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=self.aoi_ee,
            maxPixels=1e13,
            scale=irrigation30.resolution).getInfo()
        climate_df = pd.DataFrame(
            [climate_dict],
            columns=[band + str(i).zfill(2) for i in range(1, 13)],
            index=[0])
        return climate_df.to_numpy() * multiplier

    def __identify_peak(self, y_raw):
        '''Find the crop-cycle NDVI peaks in a 12-value monthly series.

        Parameters:
            y_raw: 12 monthly NDVI values, January first
        Returns:
            (month indices 0-11, month names, peak NDVI values), one entry per
            crop cycle; all three lists are empty when no peak reaches
            crop_ndvi_threashold.
        '''
        # Peaks cannot be identified if it's the first or last number in a
        # series, so the year is copied three times and only peaks found in
        # the middle copy are kept.
        y = np.concatenate((y_raw, y_raw, y_raw))
        peak_index_raw, peak_props = find_peaks(
            y, height=irrigation30.crop_ndvi_threashold)
        if len(peak_index_raw) == 0:
            return [], [], []

        # Sometimes there are multiple peaks in a single crop season: peaks
        # exactly two months apart share a group, whose index is averaged and
        # whose value is the group maximum.
        peak_grp = [0]
        for gap in np.diff(peak_index_raw):
            peak_grp.append(peak_grp[-1] if gap == 2 else peak_grp[-1] + 1)
        peak_grp_df = pd.DataFrame({
            'peak_grp': peak_grp,
            'peak_index': peak_index_raw,
            'peak_value': peak_props['peak_heights'],
        })
        peak_grp_agg_df = peak_grp_df.groupby('peak_grp').agg({
            'peak_index': 'mean',
            'peak_value': 'max'
        })

        final_peak_index, final_peak_month, final_peak_value = [], [], []
        for i, value in zip(peak_grp_agg_df['peak_index'].to_numpy(),
                            peak_grp_agg_df['peak_value'].to_numpy()):
            if 12 <= i < 24:
                month = int(i - 12)
                final_peak_index.append(month)
                final_peak_month.append(irrigation30.month_lst[month])
                final_peak_value.append(value)
        return final_peak_index, final_peak_month, final_peak_value

    def __load_climate_data(self):
        '''Pull precipitation and max/min temperature once (if not already
        loaded) and refresh the monthly average temperature.'''
        if len(self.precipitation) == 0:
            self.precipitation = self.__pull_TerraClimate_data('pr')[0]
        if len(self.temperature_max) == 0:
            self.temperature_max = self.__pull_TerraClimate_data(
                'tmmx', multiplier=0.1)[0]
            self.temperature_min = self.__pull_TerraClimate_data(
                'tmmn', multiplier=0.1)[0]
        self.temperature_avg = np.mean(
            [self.temperature_max, self.temperature_min], axis=0)

    def __identify_label(self, cluster_result):
        '''Label each cluster as rainfed or irrigated and add the cluster map
        to self.image as a 'prediction' band.

        A cluster without NDVI peaks is rainfed. Otherwise each peak (crop
        cycle) is rainfed when the mean precipitation of the peak month and
        the month before reaches water_need_threshold (scaled by 0.7 when the
        mean temperature of those months is below 15), else irrigated.
        Sets self.label, self.simple_label and self.comment.
        '''
        def previous_month(month_index):
            '''Index of the month before month_index, wrapping Jan -> Dec.'''
            return int((month_index - 1) % 12)

        self.label = []
        # Reset so a re-fit with fewer clusters leaves no stale comments.
        self.comment = dict()
        for i in range(self.nClusters):
            peak_index, peak_month, _ = self.__identify_peak(self.avg_ndvi[i])
            if len(peak_index) == 0:
                self.label.append('Cluster {}: Rainfed'.format(i))
                self.comment[i] = 'rainfed'
                continue

            # Climate data is only needed (and pulled) once a cluster has peaks
            self.__load_climate_data()
            temp_label = []
            temp_comment = '{}-crop cycle annually | '.format(len(peak_index))
            for p_index, p_month in zip(peak_index, peak_month):
                # Calculate the precipitation the month before the peak and at
                # the peak. Depending on whether it's a fresh or dry harvested
                # crop, the water need after the mid-season is different.
                # Reference: http://www.fao.org/3/s2022e/s2022e02.htm#TopOfPage
                p_lst = [previous_month(p_index), p_index]
                pr_mean = self.precipitation[p_lst].mean()
                # Lower temperature reduces water need (same reference)
                if self.temperature_avg[p_lst].mean() < 15:
                    temperature_adj = 0.7
                else:
                    temperature_adj = 1
                if pr_mean >= irrigation30.water_need_threshold * temperature_adj:
                    cycle = 'Rainfed'
                else:
                    cycle = 'Irrigated'
                temp_label.append(cycle)
                temp_comment += '{} around {}; '.format(cycle.lower(), p_month)
            self.label.append('Cluster {}: '.format(i) + '+'.join(temp_label))
            self.comment[i] = temp_comment

        self.simple_label = [
            'Irrigated' if 'Irrigated' in lbl else 'Rainfed'
            for lbl in self.label
        ]
        self.image = self.image.addBands(
            ee.Image(cluster_result.select('cluster')).rename('prediction'))

    def plot_precipitation(self):
        '''Plots monthly precipitation from TerraClimate'''
        if len(self.precipitation) == 0:
            self.precipitation = self.__pull_TerraClimate_data('pr')[0]
        fig, ax = plt.subplots(figsize=(12, 6))
        plt.plot(irrigation30.month_lst,
                 self.precipitation,
                 label='Precipitation')
        plt.legend()

    def plot_temperature_max_min(self):
        '''Plots max and min temperature from TerraClimate (always re-pulled)'''
        self.temperature_max = self.__pull_TerraClimate_data('tmmx',
                                                             multiplier=0.1)[0]
        self.temperature_min = self.__pull_TerraClimate_data('tmmn',
                                                             multiplier=0.1)[0]
        fig, ax = plt.subplots(figsize=(12, 6))
        plt.plot(irrigation30.month_lst,
                 self.temperature_max,
                 label='Max Temperature')
        plt.plot(irrigation30.month_lst,
                 self.temperature_min,
                 label='Min Temperature')
        plt.legend()

    def fit_predict(self):
        '''Cluster the cropland NDVI time series and label the clusters.

        Trains k-means (maxClusters_set clusters) on a stratified sample of
        GFSAD30 cropland pixels, computes each non-empty cluster's mean
        monthly NDVI into self.avg_ndvi, then labels the clusters.

        Raises:
            RuntimeError: if building the Sentinel-2 input image fails.
        '''
        try:
            self.image = self.__pull_Sentinel2_data()
        except Exception as err:
            raise RuntimeError(
                'GEE will run into issues due to missing images') from err

        # Sample only cropland (gfsad30 == 2): the other classes get 0 points.
        training_FC = (self.image
                       .cast({'gfsad30': "int8"},
                             ['gfsad30', 'gfsad1000'] + irrigation30.ndvi_lst)
                       .stratifiedSample(region=self.aoi_ee,
                                         classBand='gfsad30',
                                         numPoints=self.nSample,
                                         classValues=[0, 1, 3],
                                         classPoints=[0, 0, 0],
                                         scale=irrigation30.resolution)
                       .select(irrigation30.ndvi_lst))

        # Instantiate the clusterer and train it. (wekaCascadeKMeans and
        # wekaXMeans were considered but take much longer to run.)
        clusterer = ee.Clusterer.wekaKMeans(self.maxClusters_set).train(
            training_FC, inputProperties=irrigation30.ndvi_lst)

        # Cluster the input using the trained clusterer.
        cluster_result = self.image.cluster(clusterer)

        print('Model building...')
        cluster_output = []
        for i in range(self.maxClusters_set):
            cluster_mean = self.image.select(irrigation30.ndvi_lst).mask(
                cluster_result.select('cluster').eq(i)).reduceRegion(
                    reducer=ee.Reducer.mean(),
                    geometry=self.aoi_ee,
                    maxPixels=1e13,
                    scale=irrigation30.resolution).getInfo()
            # A None mean means no pixel was assigned to cluster i; stop at
            # the first such cluster.
            if cluster_mean.get('ndvi01') is None:
                break
            cluster_output.append(cluster_mean)
        self.nClusters = len(cluster_output)

        cluster_df = pd.DataFrame(
            cluster_output,
            columns=irrigation30.ndvi_lst,
            index=['Cluster_' + str(i) for i in range(self.nClusters)])
        self.avg_ndvi = cluster_df.to_numpy()

        self.__identify_label(cluster_result)

        print('Model complete')

    def plot_map(self):
        '''Plot folium map using GEE api - the map includes the area of interest box,
        the cluster prediction, GFSAD1000 and monthly ndvi readings.

        Call after fit_predict(). Prints legends for the prediction and
        GFSAD1000 layers and returns the folium.Map.
        '''
        def add_ee_layer(folium_map, ee_object, vis_params, show, name):
            '''Add an Earth Engine Image, ImageCollection (as its median),
            Geometry, or FeatureCollection (as painted outlines) to the map.
            A layer that fails to render is reported and skipped.'''
            def add_tile_layer(image):
                map_id_dict = ee.Image(image).getMapId(vis_params)
                folium.raster_layers.TileLayer(
                    tiles=map_id_dict['tile_fetcher'].url_format,
                    attr='Google Earth Engine',
                    name=name,
                    overlay=True,
                    control=True,
                    show=show).add_to(folium_map)

            try:
                if isinstance(ee_object, ee.image.Image):
                    add_tile_layer(ee_object)
                elif isinstance(ee_object, ee.imagecollection.ImageCollection):
                    add_tile_layer(ee_object.median())
                elif isinstance(ee_object, ee.geometry.Geometry):
                    folium.GeoJson(data=ee_object.getInfo(),
                                   name=name,
                                   overlay=True,
                                   control=True).add_to(folium_map)
                elif isinstance(ee_object,
                                ee.featurecollection.FeatureCollection):
                    add_tile_layer(ee.Image().paint(ee_object, 0, 2))
            except Exception:
                # Best effort: one broken layer must not stop the whole map.
                print("Could not display {}".format(name))

        # Add EE drawing method to folium.
        folium.Map.add_ee_layer = add_ee_layer

        myMap = folium.Map(location=[self.center_lat, self.center_lon],
                           zoom_start=8)
        aoi_shapely = self.__create_bounding_box_shapely()
        folium.GeoJson(aoi_shapely, name="Area of Interest").add_to(myMap)

        # Colour cluster i with cluster_color[i] so the map agrees with the
        # printed legend and plot_avg_ndvi; keep >= 2 colours so min < max.
        n_colors = max(self.nClusters, 2)
        visParams = {
            'min': 0,
            'max': n_colors - 1,
            'palette': irrigation30.cluster_color[:n_colors]
        }
        myMap.add_ee_layer(self.image.select('prediction'),
                           visParams,
                           show=True,
                           name='Prediction')
        #     0: Non-croplands (black)
        #     1: Croplands: irrigation major (green)
        #     2: Croplands: irrigation minor (lighter green)
        #     3: Croplands: rainfed (yellow)
        #     4: Croplands: rainfed, minor fragments (yellow orange)
        #     5: Croplands: rainfed, very minor fragments (orange)
        visParams = {
            'min': 0,
            'max': 5,
            'palette':
            ['black', 'green', 'a9e1a9', 'yellow', 'ffdb00', '#ffa500']
        }
        myMap.add_ee_layer(self.image.select('gfsad1000'),
                           visParams,
                           show=False,
                           name='GFSAD1000')
        visParams = {'min': 0, 'max': 1, 'palette': ['red', 'yellow', 'green']}
        for temp_band, month_label in zip(irrigation30.ndvi_lst,
                                          irrigation30.month_lst):
            myMap.add_ee_layer(self.image.select(temp_band),
                               visParams,
                               show=False,
                               name='NDVI ' + month_label)
        myMap.add_child(folium.LayerControl())
        folium.Marker([self.center_lat, self.center_lon],
                      tooltip='center').add_to(myMap)

        print('============ Prediction Layer Legend ============')
        # print the comments for each cluster
        for i in range(self.nClusters):
            print('Cluster {} ({}): {}'.format(i,
                                               irrigation30.cluster_color[i],
                                               self.comment[i]))
        print('============ GFSAD1000 Layer Legend ============')
        print('Croplands: irrigation major (green)')
        print('Croplands: irrigation minor (lighter green)')
        print('Croplands: rainfed (yellow)')
        print('Croplands: rainfed, minor fragments (yellow orange)')
        print('Croplands: rainfed, very minor fragments (orange)')
        print('================================================')
        return myMap

    def plot_avg_ndvi(self):
        '''Plot each cluster's average monthly NDVI, coloured with cluster_color
        and labelled with its rainfed/irrigated label (call after fit_predict()).'''
        fig, ax = plt.subplots(figsize=(12, 6))
        for i in range(self.nClusters):
            plt.plot(irrigation30.month_lst,
                     self.avg_ndvi[i],
                     label=self.label[i],
                     color=irrigation30.cluster_color[i])
        plt.legend()

    def __binary_image(self):
        '''Return a one-band image that is 1 where a pixel's cluster is labelled
        Irrigated and 0 where it is Rainfed.

        A `binary_image` attribute set on the instance takes precedence.

        Raises:
            RuntimeError: if no cluster labels exist yet (fit_predict not run).
        '''
        existing = getattr(self, 'binary_image', None)
        if existing is not None:
            return existing
        if not self.simple_label:
            raise RuntimeError(
                'Call fit_predict() before exporting the binary image')
        irrigated = [1 if lbl == 'Irrigated' else 0 for lbl in self.simple_label]
        return (self.image.select('prediction')
                .remap(list(range(len(irrigated))), irrigated)
                .rename('irrigated'))

    def write_image_asset(self, image_asset_id, write_binary_version=False):
        '''Writes predicted image out as an image to Google Earth Engine as an asset.

        Parameters:
            image_asset_id: asset name, relative to base_asset_directory
            write_binary_version: export the 0/1 irrigated map instead of the
                full image (all NDVI, GFSAD and prediction bands)
        '''
        image_asset_id = self.base_asset_directory + '/' + image_asset_id
        image = self.__binary_image() if write_binary_version else self.image
        task = ee.batch.Export.image.toAsset(crs=self.model_projection,
                                             region=self.aoi_ee,
                                             image=image,
                                             scale=irrigation30.resolution,
                                             assetId=image_asset_id,
                                             maxPixels=1e13)
        task.start()

    def write_image_google_drive(self, filename):
        '''Writes predicted image out as an image to Google Drive as a TIF file.

        Exports a `predicted_image` attribute if one was set on the instance,
        otherwise the 'prediction' band produced by fit_predict().
        '''
        predicted_image = getattr(self, 'predicted_image', None)
        if predicted_image is None:
            predicted_image = self.image.select('prediction')
        task = ee.batch.Export.image.toDrive(crs=self.model_projection,
                                             region=self.aoi_ee,
                                             image=predicted_image,
                                             scale=irrigation30.resolution,
                                             description=filename,
                                             maxPixels=1e13)
        print("Writing To Google Drive filename= ", filename)
        task.start()
Example #3
0
def get_gee_data(aoi, date_range=None, mode="sentinel_raw",
                 band_names=None):
    """Build a Google Earth Engine download URL for an image over ``aoi``.

    Parameters
    ----------
    aoi : list
        Area of interest as a list of ``[xcoord, ycoord]`` points forming
        the polygon ring.
    date_range : list of str, optional
        ``[start_date, end_date]`` in 'YYYY-MM-DD' format, defaults to
        ``["2020-05-01", "2020-07-01"]``. Only used by ``mode="sentinel_raw"``;
        the 'ndvi' mode always uses the 365 days ending now.
    mode : str
        'sentinel_raw' for Sentinel-2 surface-reflectance images,
        'global_land_cover' for Copernicus global land cover maps,
        'ndvi' for a monthly vegetation-index time series.
    band_names : list of str, optional
        Only for ``mode="sentinel_raw"``: bands to keep from the original
        image, defaults to ``["B2", "B3", "B4", "B8"]``.

    Returns
    -------
    str
        URL returned by ``getDownloadURL`` for the image at 10 m scale in
        EPSG:4326 (GeoTIFF requested), restricted to the ``aoi`` region.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    from datetime import datetime, timedelta

    # Fresh defaults per call instead of shared mutable list defaults.
    if date_range is None:
        date_range = ["2020-05-01", "2020-07-01"]
    if band_names is None:
        band_names = ["B2", "B3", "B4", "B8"]
    # Validate up front: an unknown mode used to surface much later as an
    # UnboundLocalError on ``fin_img``.
    if mode not in ("sentinel_raw", "global_land_cover", "ndvi"):
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'sentinel_raw', "
            "'global_land_cover' or 'ndvi'")

    # Initialize the Earth Engine module, running the auth flow if needed.
    # Never embed account credentials in source code.
    try:
        ee.Initialize()
    except ee.ee_exception.EEException:
        print("MISSING credentials!!!! \n you have to authenticate to Google earth engine.")
        ee.Authenticate()
        # Authenticate() only stores credentials; a session must still be opened.
        ee.Initialize()

    # Area of interest as gee object
    aoi_obj = ee.Geometry.Polygon([aoi])

    print(f"Downloading {mode} image for coordinates {aoi}")

    if mode == "sentinel_raw":
        start_date = ee.Date(date_range[0])
        end_date = ee.Date(date_range[1])
        sent_coll = (ee.ImageCollection("COPERNICUS/S2_SR")
                     .filterBounds(aoi_obj)
                     .filterDate(start_date, end_date))
        # Cloud-mask every scene, then composite with the per-pixel mean.
        cloud_free_coll = sent_coll.map(mask_clouds)
        fin_img = cloud_free_coll.mean().select(band_names)
    elif mode == "global_land_cover":
        glc = ee.ImageCollection("COPERNICUS/Landcover/100m/Proba-V-C3/Global")
        # Last element of (up to) the first 10 images in the collection.
        fin_img = ee.Image(glc.toList(10).reverse().get(0)).clip(aoi_obj)
    else:  # mode == "ndvi"
        end_dt = datetime.now()
        start_dt = end_dt - timedelta(days=365)
        print(start_dt.strftime("%Y-%m-%d"), end_dt.strftime("%Y-%m-%d"))
        start_date = ee.Date(start_dt)
        end_date = ee.Date(end_dt)

        sent_coll = (ee.ImageCollection("COPERNICUS/S2_SR")
                     .filterBounds(aoi_obj)
                     .filterDate(start_date, end_date)
                     .filterMetadata("CLOUDY_PIXEL_PERCENTAGE", "less_than", 30))
        cloud_free_coll = sent_coll.map(mask_clouds)
        ndvi_coll = cloud_free_coll.map(
            lambda img: (img.normalizedDifference(["B8", "B4"])
                         .clip(aoi_obj)
                         .set("system:time_start", img.get("system:time_start"))))

        # Monthly steps over the same 12-month window used for filtering.
        # The collection's "date_range" property describes the whole catalog
        # entry, not this filtered subset, so it is not used here.
        diff = end_date.difference(start_date, "month").round()
        date_seq = ee.List.sequence(1, diff, 1).map(
            lambda delay: start_date.advance(delay, "month"))
        print(date_seq.getInfo())

        # aggregate monthly ndvi
        monthly_ndvi_list = aggregate_ndvi(date_seq, ndvi_coll)
        fin_img = ee.ImageCollection.fromImages(monthly_ndvi_list).toBands()
        print(fin_img.getInfo())

    link = fin_img.getDownloadURL({
        'scale': 10,
        'crs': 'EPSG:4326',
        'fileFormat': 'GeoTIFF',
        'region': aoi_obj})
    return link
Example #4
0
    def __init__(self, **kwargs):
        """Create the interactive Earth Engine map.

        Initializes Earth Engine, running ``ee.Authenticate()`` first if
        ``ee.Initialize()`` raises. Normalizes the folium-style ``location``
        / ``zoom_start`` keyword arguments to ``center`` / ``zoom``, with
        defaults of ``[40, -100]`` and ``4``, and forwards all keyword
        arguments to the parent map class.

        It then adds the layer, scale, fullscreen, measure and draw
        controls, the ROADMAP basemap, the inspector/plotting checkboxes
        and an output widget. Finally it registers a click handler for
        inspecting and plotting Earth Engine layers.

        Args:
            **kwargs: Keyword arguments for the parent map constructor,
                after the center/zoom normalization described above.
        """

        # Authenticates Earth Engine and initialize an Earth Engine session
        try:
            ee.Initialize()
        except Exception as e:
            # Initialization failed (presumably missing credentials): run the
            # auth flow, then retry once.
            ee.Authenticate()
            ee.Initialize()

        # Default map center location and zoom level
        latlon = [40, -100]
        zoom = 4

        # Interchangeable parameters between ipyleaflet and folium
        if 'location' in kwargs.keys():
            kwargs['center'] = kwargs['location']
            kwargs.pop('location')
        if 'center' in kwargs.keys():
            latlon = kwargs['center']
        else:
            kwargs['center'] = latlon

        if 'zoom_start' in kwargs.keys():
            kwargs['zoom'] = kwargs['zoom_start']
            kwargs.pop('zoom_start')
        if 'zoom' in kwargs.keys():
            zoom = kwargs['zoom']
        else:
            kwargs['zoom'] = zoom

        # Inherit the ipyleaflet Map class
        super().__init__(**kwargs)
        self.scroll_wheel_zoom = True
        self.layout.height = '550px'

        # Standard map controls; each one is also stored on the instance.
        layer_control = LayersControl(position='topright')
        self.add_control(layer_control)
        self.layer_control = layer_control

        scale = ScaleControl(position='bottomleft')
        self.add_control(scale)
        self.scale_control = scale

        fullscreen = FullScreenControl()
        self.add_control(fullscreen)
        self.fullscreen_control = fullscreen

        measure = MeasureControl(position='bottomleft',
                                 active_color='orange',
                                 primary_length_unit='kilometers')
        self.add_control(measure)
        self.measure_control = measure

        # Initial basemap layer (ee_basemaps is defined elsewhere in the module).
        self.add_layer(ee_basemaps['ROADMAP'])

        # Drawing tools: blue outline for markers, rectangles and circles.
        draw_control = DrawControl(
            marker={'shapeOptions': {
                'color': '#0000FF'
            }},
            rectangle={'shapeOptions': {
                'color': '#0000FF'
            }},
            circle={'shapeOptions': {
                'color': '#0000FF'
            }},
            circlemarker={},
        )

        self.draw_count = 0  # The number of shapes drawn by the user using the DrawControl
        # The list of Earth Engine Geometry objects converted from geojson
        self.draw_features = []
        # The Earth Engine Geometry object converted from the last drawn feature
        self.draw_last_feature = None
        self.draw_layer = None

        self.plot_widget = None  # The plot widget for plotting Earth Engine data
        self.plot_control = None  # The plot control for interacting plotting
        self.random_marker = None

        # Parallel lists: layer objects and their display names.
        self.ee_layers = []
        self.ee_layer_names = []
        self.ee_raster_layers = []
        self.ee_raster_layer_names = []

        # Handles draw events
        def handle_draw(target, action, geo_json):
            # Convert the drawn GeoJSON into an ee.Feature, append it to the
            # accumulated features, and show all of them as one EE layer.
            try:
                self.draw_count += 1
                geom = geojson_to_ee(geo_json, False)
                feature = ee.Feature(geom)
                self.draw_last_feature = feature
                self.draw_features.append(feature)
                collection = ee.FeatureCollection(self.draw_features)
                ee_draw_layer = ee_tile_layer(collection, {'color': 'blue'},
                                              'Drawing Features', True, 0.5)

                # First drawing adds the layer; later drawings replace it.
                if self.draw_count == 1:
                    self.add_layer(ee_draw_layer)
                    self.draw_layer = ee_draw_layer
                else:
                    self.substitute_layer(self.draw_layer, ee_draw_layer)
                    self.draw_layer = ee_draw_layer

                # Clear the shapes held by the draw control itself; the
                # features are now displayed through the EE layer above.
                draw_control.clear()
            except Exception as e:
                # On any failure, print the error and reset all drawing state.
                print(e)
                print("There was an error creating Earth Engine Feature.")
                self.draw_count = 0
                self.draw_features = []
                self.draw_last_feature = None
                self.draw_layer = None

        draw_control.on_draw(handle_draw)
        self.add_control(draw_control)
        self.draw_control = draw_control

        # Dropdown widget for plotting
        self.plot_dropdown_control = None
        self.plot_dropdown_widget = None
        self.plot_options = {}

        # Markers for clicked plot locations (used when plot_options has
        # 'add_marker_cluster' set).
        self.plot_marker_cluster = MarkerCluster(name="Marker Cluster")
        self.plot_coordinates = []
        self.plot_markers = []
        self.plot_last_click = []
        self.plot_all_clicks = []

        # Adds Inspector widget
        inspector_checkbox = widgets.Checkbox(
            value=False,
            description='Use Inspector',
            indent=False,
            layout=widgets.Layout(height='18px'))
        inspector_checkbox.layout.width = '18ex'

        # Adds Plot widget
        plot_checkbox = widgets.Checkbox(
            value=False,
            description='Use Plotting',
            indent=False,
        )
        plot_checkbox.layout.width = '18ex'
        self.plot_checkbox = plot_checkbox

        # Both checkboxes share one WidgetControl in the top-right corner.
        vb = widgets.VBox(children=[inspector_checkbox, plot_checkbox])

        chk_control = WidgetControl(widget=vb, position='topright')
        self.add_control(chk_control)
        self.inspector_control = chk_control

        self.inspector_checked = inspector_checkbox.value
        self.plot_checked = plot_checkbox.value

        def inspect_chk_changed(b):
            # `output` is assigned below; the closure looks it up when the
            # callback fires, after __init__ has bound it.
            self.inspector_checked = inspector_checkbox.value
            if not self.inspector_checked:
                output.clear_output()

        inspector_checkbox.observe(inspect_chk_changed)

        # Output panel where the inspector prints layer values.
        output = widgets.Output(layout={'border': '1px solid black'})
        output_control = WidgetControl(widget=output, position='topright')
        self.add_control(output_control)

        def plot_chk_changed(button):
            # Checking "Use Plotting" adds a dropdown of raster layer names.
            # Unchecking removes the dropdown, any plot control, and the
            # marker cluster layer.

            if button['name'] == 'value' and button['new']:
                self.plot_checked = True
                plot_dropdown_widget = widgets.Dropdown(options=list(
                    self.ee_raster_layer_names), )
                plot_dropdown_widget.layout.width = '18ex'
                self.plot_dropdown_widget = plot_dropdown_widget
                plot_dropdown_control = WidgetControl(
                    widget=plot_dropdown_widget, position='topright')
                self.plot_dropdown_control = plot_dropdown_control
                self.add_control(plot_dropdown_control)
            elif button['name'] == 'value' and (not button['new']):
                self.plot_checked = False
                plot_dropdown_widget = self.plot_dropdown_widget
                plot_dropdown_control = self.plot_dropdown_control
                self.remove_control(plot_dropdown_control)
                # NOTE(review): these `del`s only unbind the local names;
                # self.plot_dropdown_widget / self.plot_dropdown_control still
                # reference the removed widgets.
                del plot_dropdown_widget
                del plot_dropdown_control
                if self.plot_control in self.controls:
                    plot_control = self.plot_control
                    plot_widget = self.plot_widget
                    self.remove_control(plot_control)
                    self.plot_control = None
                    self.plot_widget = None
                    del plot_control
                    del plot_widget
                if self.plot_marker_cluster is not None and self.plot_marker_cluster in self.layers:
                    self.remove_layer(self.plot_marker_cluster)

        plot_checkbox.observe(plot_chk_changed)

        def handle_interaction(**kwargs):
            # On a map click:
            #  - with "Use Inspector" checked, print the values of every entry
            #    in self.ee_layers at the clicked point;
            #  - with "Use Plotting" checked, plot the band values of the
            #    raster layer selected in the dropdown.

            latlon = kwargs.get('coordinates')
            # print(latlon)
            if kwargs.get('type') == 'click' and self.inspector_checked:
                self.default_style = {'cursor': 'wait'}

                sample_scale = self.getScale()
                layers = self.ee_layers

                with output:

                    output.clear_output(wait=True)
                    for index, ee_object in enumerate(layers):
                        # The click coordinates are treated as [lat, lon]; they
                        # are reversed into the [lon, lat] order ee.Geometry.Point takes.
                        xy = ee.Geometry.Point(latlon[::-1])
                        layer_names = self.ee_layer_names
                        layer_name = layer_names[index]
                        object_type = ee_object.__class__.__name__

                        try:
                            # Normalize the layer: mosaic image collections, and
                            # wrap geometries/features as a FeatureCollection.
                            if isinstance(ee_object, ee.ImageCollection):
                                ee_object = ee_object.mosaic()
                            elif isinstance(ee_object, ee.geometry.Geometry) or isinstance(ee_object, ee.feature.Feature) \
                                    or isinstance(ee_object, ee.featurecollection.FeatureCollection):
                                ee_object = ee.FeatureCollection(ee_object)

                            if isinstance(ee_object, ee.Image):
                                # Per-band values at the point, at the current map scale.
                                item = ee_object.reduceRegion(
                                    ee.Reducer.first(), xy,
                                    sample_scale).getInfo()
                                b_name = 'band'
                                if len(item) > 1:
                                    b_name = 'bands'
                                print("{}: {} ({} {})".format(
                                    layer_name, object_type, len(item),
                                    b_name))
                                keys = item.keys()
                                for key in keys:
                                    print("  {}: {}".format(key, item[key]))
                            elif isinstance(ee_object, ee.FeatureCollection):
                                # Properties of the first feature intersecting the point.
                                filtered = ee_object.filterBounds(xy)
                                size = filtered.size().getInfo()
                                if size > 0:
                                    first = filtered.first()
                                    props = first.toDictionary().getInfo()
                                    b_name = 'property'
                                    if len(props) > 1:
                                        b_name = 'properties'
                                    print("{}: Feature ({} {})".format(
                                        layer_name, len(props), b_name))
                                    keys = props.keys()
                                    for key in keys:
                                        print("  {}: {}".format(
                                            key, props[key]))
                        except Exception as e:
                            print(e)

                self.default_style = {'cursor': 'crosshair'}
            if kwargs.get('type') == 'click' and self.plot_checked and len(
                    self.ee_raster_layers) > 0:
                # Look up the raster layer chosen in the plotting dropdown.
                plot_layer_name = self.plot_dropdown_widget.value
                layer_names = self.ee_raster_layer_names
                layers = self.ee_raster_layers
                index = layer_names.index(plot_layer_name)
                ee_object = layers[index]

                if isinstance(ee_object, ee.ImageCollection):
                    ee_object = ee_object.mosaic()

                try:
                    self.default_style = {'cursor': 'wait'}
                    plot_options = self.plot_options
                    # Use plot_options['sample_scale'] when set, otherwise the
                    # current map scale.
                    sample_scale = self.getScale()
                    if 'sample_scale' in plot_options.keys() and (
                            plot_options['sample_scale'] is not None):
                        sample_scale = plot_options['sample_scale']
                    # Default the plot title to the layer name for this click only.
                    if 'title' not in plot_options.keys():
                        plot_options['title'] = plot_layer_name
                    if ('add_marker_cluster' in plot_options.keys()
                        ) and plot_options['add_marker_cluster']:
                        # Record the click and show it as a marker in the cluster.
                        plot_coordinates = self.plot_coordinates
                        markers = self.plot_markers
                        marker_cluster = self.plot_marker_cluster
                        plot_coordinates.append(latlon)
                        self.plot_last_click = latlon
                        self.plot_all_clicks = plot_coordinates
                        markers.append(Marker(location=latlon))
                        marker_cluster.markers = markers
                        self.plot_marker_cluster = marker_cluster

                    band_names = ee_object.bandNames().getInfo()
                    xy = ee.Geometry.Point(latlon[::-1])
                    dict_values = ee_object.sample(
                        xy,
                        scale=sample_scale).first().toDictionary().getInfo()
                    band_values = list(dict_values.values())
                    self.plot(band_names, band_values, **plot_options)
                    # Drop the defaulted title so the next click gets its own.
                    # NOTE(review): this runs only on success; if plotting raises,
                    # the defaulted title stays in self.plot_options.
                    if plot_options['title'] == plot_layer_name:
                        del plot_options['title']
                    self.default_style = {'cursor': 'crosshair'}
                except Exception as e:
                    if self.plot_widget is not None:
                        with self.plot_widget:
                            self.plot_widget.clear_output()
                            print("No data for the clicked location.")
                    else:
                        print(e)
                    self.default_style = {'cursor': 'crosshair'}

        self.on_interaction(handle_interaction)
 def __init__(self, authenticate):
     if authenticate == 1:
         ee.Authenticate()
         ee.Initialize()
     else:
         ee.Initialize()