def init_cells(direction_map):
    """
    Build an nx x ny grid of Cell objects from a flow-direction map and
    link each cell to its downstream neighbour.

    Parameters
    ----------
    direction_map : 2D array (nx, ny)
        Flow-direction codes understood by direction_and_value.to_indices.

    Returns
    -------
    list of lists of Cell
        cells[i][j] has .x == i, .y == j and its next cell set (None when
        the flow direction points outside the grid).
    """
    cells = []
    nx, ny = direction_map.shape
    for i in xrange(nx):
        cells.append([])
        for j in xrange(ny):
            theCell = Cell()
            theCell.x = i
            theCell.y = j
            # BUG FIX: previously appended a brand-new Cell() here, which
            # discarded theCell and left every stored cell without x/y set.
            cells[-1].append(theCell)

    # second pass: wire each cell to its downstream neighbour
    for i in xrange(nx):
        for j in xrange(ny):
            i_next, j_next = direction_and_value.to_indices(i, j, direction_map[i,j])
            # flow leaving the domain -> no downstream cell
            if i_next < 0 or j_next < 0 or i_next == nx or j_next == ny:
                nextCell = None
            else:
                nextCell = cells[i_next][j_next]
            cells[i][j].set_next(nextCell)

    return cells
# Esempio n. 2  (scraper artifact: stray example-separator text, kept as a comment)
# 0
    def _connect_cells(self):
        '''
        Inits and connects underlying cell list
        '''

        flowDirValues = self.ncFile.variables['flow_direction_value'].data
        print self.nx, self.ny


        for i in xrange(self.nx):
            for j in xrange(self.ny):
                self.cells[i][j].drainage_area = self.accumulation_area[i, j]
                iNext, jNext = direction_and_value.to_indices(i, j, flowDirValues[i, j])
                if iNext >= 0 and iNext < self.nx and jNext >= 0 and jNext < self.ny:
                    self.cells[i][j].set_next(self.cells[iNext][jNext])

        #calculate number of previous cells for each cell
        for i in xrange(self.nx):
            for j in xrange(self.ny):
                theCell = self.cells[i][j]
                if theCell.number_of_upstream_cells >= 0:
                    continue
                theCell.calculate_number_of_upstream_cells()
def main():
    """
    Compare modelled and observed mean-annual-cycle streamflow at model
    grid points matched to gauging stations.

    For each selected station: builds daily climatological normals (on a
    non-leap stamp year) for the model, the station and each ensemble
    member, plots them in a grid of subplots, annotates Nash-Sutcliffe
    and correlation scores, saves the figure, then runs skill
    calculations and a station-location plot.

    NOTE(review): relies on module-level names not visible in this chunk
    (members, data_select, polar_stereographic, selected_station_ids,
    read_station_data, get_corresponding_station,
    get_unrouted_streamflow_for, calculate_skills,
    plot_selected_stations, Cell, IndexObject, ...) -- confirm they are
    defined/imported at module scope.
    """
    # station ids excluded from the comparison
    skip_ids = ['081007', '081002', "042607", "090605"]

    #comment to plot for all ensemble members
    # NOTE(review): emptying current_ids disables all per-member processing below
    members.current_ids = []


    #pylab.rcParams.update(params)
    # template for per-simulation discharge file paths (filled with a member id)
    path_format = 'data/streamflows/hydrosheds_euler9/%s_discharge_1970_01_01_00_00.nc'
    #path_format = "data/streamflows/hydrosheds_rk4_changed_partiotioning/%s_discharge_1970_01_01_00_00.nc"
    #path_format = "data/streamflows/piloted_by_ecmwf/ecmwf_nearest_neighbor_discharge_1970_01_01_00_00.nc"
    path_to_analysis_driven = path_format % members.control_id

    # per-ensemble-member data and time axes, keyed by simulation id
    simIdToData = {}
    simIdToTimes = {}
    for the_id in members.current_ids:
        thePath = path_format % the_id
        [simIdToData[the_id], simIdToTimes[the_id], i_list, j_list] = data_select.get_data_from_file(thePath)


    old = True #in the old version drainage and lon,lats in the file are 1D


    # control-run data: values (ntimes x ncells) plus the grid indices of each cell
    [ data, times, i_list, j_list ] = data_select.get_data_from_file(path_to_analysis_driven)

    cell_list = []
    ij_to_cell = {}
    prev_cell_indices = []
    tot_rof = None
    # NOTE(review): the three containers above are only populated when old == True;
    # the unrouted-streamflow computation further down depends on them -- confirm
    # this function is never run with old == False while that plot is enabled.
    if old:
        #surf_rof = data_select.get_data_from_file(path_format % ("aex",), field_name="")
        the_path = path_format % ("aex")
        static_data_path = "data/streamflows/hydrosheds_euler9/infocell9.nc"
        #ntimes x ncells
        tot_rof = data_select.get_field_from_file(the_path, field_name="total_runoff")
        cell_areas = data_select.get_field_from_file(static_data_path, field_name="AREA")

        #convert the runoff to m^3/s
        # presumably mm over km^2 -> m^3/s; TODO confirm units of total_runoff and AREA
        tot_rof *= 1.0e6 * cell_areas[i_list, j_list] / 1.0e3


        flow_dir_values = data_select.get_field_from_file(static_data_path,
            field_name="flow_direction_value")[i_list, j_list]

        # one Cell per grid point; id doubles as the position index in the data arrays
        cell_list = map(lambda i, j, the_id: Cell(id = the_id, ix = i, jy = j),
                                i_list, j_list, xrange(len(i_list)))


        # (i, j) -> Cell lookup used to wire downstream links
        ij_to_cell = dict( zip( zip(i_list, j_list), cell_list ))


        for ix, jy, aCell, dir_val in zip( i_list, j_list, cell_list, flow_dir_values):
            i_next, j_next = direction_and_value.to_indices(ix, jy, dir_val)
            the_key = (i_next, j_next)
            if ij_to_cell.has_key(the_key):
                next_cell = ij_to_cell[the_key]
            else:
                next_cell = None
            assert isinstance(aCell, Cell)
            aCell.set_next(next_cell)

        #determine list of indices of the previous cells for each cell
        #in this case they are equal to the ids

        for aCell in cell_list:
            assert isinstance(aCell, Cell)
            prev_cells = aCell.get_upstream_cells()
            prev_cell_indices.append(map(lambda c: c.id, prev_cells))
            # include the cell itself in its own upstream set
            prev_cell_indices[-1].append(aCell.id)



    # drainage area, longitudes, latitudes on the 2D grid
    if not old:
        da_2d = data_select.get_field_from_file(path_to_analysis_driven, 'accumulation_area')
        lons = data_select.get_field_from_file(path_to_analysis_driven, field_name = 'longitude')
        lats = data_select.get_field_from_file(path_to_analysis_driven, field_name = 'latitude')
    else:
        lons = polar_stereographic.lons
        lats = polar_stereographic.lats
        da_2d = np.zeros(lons.shape)
        # scatter the 1D drainage values back onto the 2D grid
        drainage = data_select.get_field_from_file(path_to_analysis_driven, 'drainage')
        for i, j, theDa in zip(i_list, j_list, drainage):
            da_2d[i, j] = theDa




    # model output cadence (daily)
    data_step = timedelta(days = 1)


    # cache station metadata/series between runs to avoid re-reading
    stations_dump = 'stations_dump.bin'
    if os.path.isfile(stations_dump):
        print 'unpickling'
        stations = pickle.load(open(stations_dump))
    else:
        stations = read_station_data()
        pickle.dump(stations, open(stations_dump, 'w'))

#   Did this to solve text encoding issues
#    reload(sys)
#    sys.setdefaultencoding('iso-8859-1')


    # accumulators filled once per accepted station
    selected_stations = []
    selected_model_values = []
    selected_station_values = []

    grid_drainages = []
    grid_lons = []
    grid_lats = []
    plot_utils.apply_plot_params(width_pt= None, font_size=9, aspect_ratio=2.5)
    #plot_utils.apply_plot_params(font_size=9, width_pt=None)
    # subplot grid: 5 rows x 2 columns
    ncols = 2
    gs = gridspec.GridSpec(5, ncols)
    fig = plt.figure()

    assert isinstance(fig, Figure)

    current_subplot = 0

    label1 = "modelled"
    label2 = "observed"
    line1 = None
    line2 = None
    lines_for_mems = None
    labels_for_mems = None
    #fig.subplots_adjust(hspace = 0.9, wspace = 0.4, top = 0.9)




    # wrap each (position index, i, j) triple so the list can be sorted
    index_objects = []
    for index, i, j in zip( range(len(i_list)) , i_list, j_list):
        index_objects.append(IndexObject(positionIndex = index, i = i, j = j))

    #sort by latitude
    # NOTE(review): sorts on j descending -- presumably j grows with latitude; confirm
    index_objects.sort( key = lambda x: x.j, reverse = True)

    #simulation id to continuous data map
    simIdToContData = {}
    for the_id in members.all_current:
        simIdToContData[the_id] = {}

    for indexObj in index_objects:
        i = indexObj.i
        j = indexObj.j
        # @type indexObj IndexObject
        index = indexObj.positionIndex
        # match this grid point to a gauging station by location and drainage area
        station = get_corresponding_station(lons[i, j], lats[i, j], da_2d[i, j], stations)


        if station is None or station in selected_stations:
            continue

        #if you want to compare with stations add their ids to the selected
        # NOTE(review): selected_station_ids is not defined in this function --
        # presumably a module-level list; verify it exists before running
        if station.id not in selected_station_ids:
            continue


        #skip some stations
        if station.id in skip_ids:
            continue


        #try now to find the point with the closest drainage area
#        current_diff = np.abs(station.drainage_km2 - da_2d[i, j])
#        for di in xrange(-1,2):
#            for dj in xrange(-1,2):
#                the_diff = np.abs(station.drainage_km2 - da_2d[i + di, j + dj])
#                if the_diff < current_diff: #select different grid point
#                    current_diff = the_diff
#                    i = i + di
#                    j = j + dj
#                    indexObj.i = i
#                    indexObj.j = j




        #found station plot data
        print station.name


        # overlapping period between model output and station record
        start_date = max( np.min(times), np.min(station.dates))
        end_date = min( np.max(times),  np.max(station.dates))

        # trim to whole calendar years only
        if start_date.day > 1 or start_date.month > 1:
            start_date = datetime(start_date.year + 1, 1, 1,0,0,0)

        if end_date.day < 31 or end_date.month < 12:
            end_date = datetime(end_date.year - 1, 12, 31,0,0,0)



        if end_date < start_date:
            continue


        #select data for years that do not have gaps
        start_year = start_date.year
        end_year = end_date.year
        continuous_station_data = {}
        continuous_model_data = {}
        num_of_continuous_years = 0
        for year in xrange(start_year, end_year + 1):
            # @type station Station
            station_data = station.get_continuous_dataseries_for_year(year)
            # keep only years with (nearly) complete daily records
            if len(station_data) >= 365:
                num_of_continuous_years += 1

                #save station data
                for d, v in station_data.iteritems():
                    continuous_station_data[d] = v

                #save model data
                # relies on `times` being sorted chronologically for the early break
                for t_index, t in enumerate(times):
                    if t.year > year: break
                    if t.year < year: continue
                    continuous_model_data[t] = data[t_index, index]
                #fill the map sim id to cont model data
                for the_id in members.current_ids:
                    #save model data
                    for t_index, t in enumerate(simIdToTimes[the_id]):
                        if t.year > year: break
                        if t.year < year: continue
                        simIdToContData[the_id][t] = simIdToData[the_id][t_index, index]


        #if the length of continuous observation is less than 10 years, skip
        # (3650 daily values ~ 10 years)
        if len(continuous_station_data) < 3650: continue

        print 'Number of continuous years for station %s is %d ' % (station.id, num_of_continuous_years)

        #skip stations with less than 20 years of usable data
        #if num_of_continuous_years < 2:
        #    continue

        selected_stations.append(station)

#        plot_total_precip_for_upstream(i_index = i, j_index = j, station_id = station.id,
#                                        subplot_count = current_subplot,
#                                        start_date = datetime(1980,01,01,00),
#                                        end_date = datetime(1996,12,31,00)
#                                        )

        #tmp (if do not need to replot streamflow)
#        current_subplot += 1
#        continue

        ##Calculate means for each day of year,
        ##as a stamp year we use 2001, ignoring the leap year
        stamp_year = 2001
        start_day = datetime(stamp_year, 1, 1, 0, 0, 0)
        stamp_dates = []
        mean_data_model = []
        mean_data_station = []
        simIdToMeanModelData = {}
        for the_id in members.all_current:
            simIdToMeanModelData[the_id] = []

        # daily climatology: average each (month, day) across all continuous years
        for day_number in xrange(365):
            the_day = start_day + day_number * data_step
            stamp_dates.append(the_day)

            model_data_for_day = []
            station_data_for_day = []

            #select model data for each simulation, day
            #and then save mean for each day
            simIdToModelDataForDay = {}
            for the_id in members.current_ids:
                simIdToModelDataForDay[the_id] = []

            for year in xrange(start_year, end_year + 1):
                the_date = datetime(year, the_day.month, the_day.day, the_day.hour, the_day.minute, the_day.second)
                if continuous_station_data.has_key(the_date):
                    model_data_for_day.append(continuous_model_data[the_date])
                    station_data_for_day.append(continuous_station_data[the_date])
                    for the_id in members.current_ids:
                        simIdToModelDataForDay[the_id].append(simIdToContData[the_id][the_date])

            assert len(station_data_for_day) > 0
            mean_data_model.append(np.mean(model_data_for_day))
            mean_data_station.append(np.mean(station_data_for_day))
            for the_id in members.current_ids:
                simIdToMeanModelData[the_id].append(np.mean(simIdToModelDataForDay[the_id]))


         #skip stations with small discharge
        #if np.max(mean_data_station) < 300:
        #    continue

        # place this station's panel in the subplot grid (row-major)
        row = current_subplot// ncols
        col = current_subplot % ncols
        ax = fig.add_subplot(gs[row, col])
        assert isinstance(ax, Axes)
        current_subplot += 1

        #put "Streamflow label on the y-axis"
        if row == 0 and col == 0:
            ax.annotate("Streamflow (${\\rm m^3/s}$)", (0.025, 0.7) , xycoords = "figure fraction",
                rotation = 90, va = "top", ha = "center")

        # unrouted streamflow = total runoff summed over the upstream cells
        # NOTE(review): prev_cell_indices/tot_rof are only filled when old == True
        selected_dates = sorted( continuous_station_data.keys() )
        unrouted_stfl = get_unrouted_streamflow_for(selected_dates = selected_dates,
            all_dates=times, tot_runoff=tot_rof, cell_indices=prev_cell_indices[index])

        unrouted_daily_normals = data_select.get_means_for_stamp_dates(stamp_dates, all_dates= selected_dates,
            all_data=unrouted_stfl)

        #Calculate Nash-Sutcliff coefficient
        mean_data_model = np.array(mean_data_model)
        mean_data_station = np.array( mean_data_station )

        #mod = _get_monthly_means(stamp_dates, mean_data_model)
        #sta = _get_monthly_means(stamp_dates, mean_data_station)

        month_dates = [ datetime(stamp_year, m, 1) for m in xrange(1,13) ]


        line1, = ax.plot(stamp_dates, mean_data_model, linewidth = 3, color = "b")
        #line1, = ax.plot(month_dates, mod, linewidth = 3, color = "b")
        upper_model = np.max(mean_data_model)

        line2, = ax.plot(stamp_dates, mean_data_station, linewidth = 3, color = "r")
        #line2, = ax.plot(month_dates, sta, linewidth = 3, color = "r")

        #line3, = ax.plot(stamp_dates, unrouted_daily_normals, linewidth = 3, color = "y")


        mod = mean_data_model
        sta = mean_data_station

        # Nash-Sutcliffe efficiency: 1 - SSE / variance of observations
        ns = 1.0 - np.sum((mod - sta) ** 2) / np.sum((sta - np.mean(sta)) ** 2)

        # clamp tiny values to exactly 0 for a cleaner annotation
        if np.abs(ns) < 0.001:
            ns = 0

        corr_coef = np.corrcoef([mod, sta])[0,1]
        # same scores for the unrouted-runoff series (currently not annotated)
        ns_unr = 1.0 - np.sum((unrouted_daily_normals - sta) ** 2) / np.sum((sta - np.mean(sta)) ** 2 )
        corr_unr = np.corrcoef([unrouted_daily_normals, sta])[0, 1]

        # relative drainage-area mismatch, in percent
        da_diff = (da_2d[i, j] - station.drainage_km2) / station.drainage_km2 * 100
        ax.annotate("ns = %.2f\nr = %.2f"
                  % (ns, corr_coef), (0.95, 0.90), xycoords = "axes fraction",
            va = "top", ha = "right",
            font_properties = FontProperties(size = 9)
        )



        #plot member simulation data
        lines_for_mems = []
        labels_for_mems = []

        #lines_for_mems.append(line3)
        #labels_for_mems.append("Unrouted total runoff")


        for the_id in members.current_ids:
            the_line, = ax.plot(stamp_dates, simIdToMeanModelData[the_id], "--", linewidth = 3)
            lines_for_mems.append(the_line)
            labels_for_mems.append(the_id)


        ##calculate mean error
        # NOTE(review): means_for_members is computed but never used afterwards
        means_for_members = []
        for the_id in members.current_ids:
            means_for_members.append(np.mean(simIdToMeanModelData[the_id]))





        upper_station = np.max(mean_data_station)
        upper_unr = np.max(unrouted_daily_normals)

        # y-axis: 3 ticks at 0, half and rounded max of the plotted curves
        upper = np.max([upper_model, upper_station])
        upper = round(upper / 100 ) * 100
        half = round( 0.5 * upper / 100 ) * 100
        if upper <= 100:
            upper = 100
            half = upper / 2

        print half, upper
        print 10 * '='

        ax.set_yticks([0, half , upper])
        assert isinstance(station, Station)

        print("i = {0}, j = {1}".format(indexObj.i, indexObj.j))
        print(lons[i,j], lats[i,j])
        print("id = {0}, da_sta = {1}, da_mod = {2}, diff = {3} %".format(station.id ,station.drainage_km2, da_2d[i,j], da_diff))

        grid_drainages.append(da_2d[i, j])
        grid_lons.append(lons[i, j])
        grid_lats.append(lats[i, j])

        selected_station_values.append(mean_data_station)
        selected_model_values.append(mean_data_model)



        #plot_swe_for_upstream(i_index = i, j_index = j, station_id = station.id)




        #plt.ylabel("${\\rm m^3/s}$")
        # panel title: station id and absolute lon/lat with hemisphere letters
        west_east = 'W' if station.longitude < 0 else 'E'
        north_south = 'N' if station.latitude > 0 else 'S'
        title_data = (station.id, np.abs(station.longitude), west_east,
                                  np.abs(station.latitude), north_south)
        ax.set_title('%s: (%3.1f%s, %3.1f%s)' % title_data)


        # x ticks on the 1st and 15th of every month of the stamp year
        date_ticks = []
        for month in xrange(1,13):
            the_date = datetime(stamp_year, month, 1)
            date_ticks.append(the_date)
            date_ticks.append(the_date + timedelta(days = 15))
        ax.xaxis.set_ticks(date_ticks)



        major_ticks = ax.xaxis.get_major_ticks()


        # show tick marks only on month starts, labels only on every 4th tick
        for imtl, mtl in enumerate(major_ticks):
            mtl.tick1line.set_visible(imtl % 2 == 0)
            mtl.tick2line.set_visible(imtl % 2 == 0)
            mtl.label1On = (imtl % 4 == 1)

#        ax.xaxis.set_major_locator(
#            mpl.dates.MonthLocator(bymonth = range(2,13,2))
#        )


        ax.xaxis.set_major_formatter(
            mpl.dates.DateFormatter('%b')
        )





    # shared legend built from the last processed station's line handles
    # NOTE(review): fails with None entries if no station passed the filters
    lines = [line1]
    lines.extend(lines_for_mems)
    lines.append(line2)
    lines = tuple( lines )


    labels = [label1]
    labels.extend(labels_for_mems)
    labels.append(label2)
    labels = tuple(labels)

    fig.legend(lines, labels, 'lower right', ncol = 1)
#    fig.text(0.05, 0.5, "Streamflow (${\\rm m^3/s}$)",
#                  rotation=90,
#                  ha = 'center', va = 'center'
#                  )


    fig.tight_layout(pad = 2)
    fig.savefig('performance_error.png')





   # assert len(selected_dates_with_gw[0]) == len(selected_station_dates[0])

    do_skill_calculation = True
    if do_skill_calculation:
        calculate_skills(selected_stations,
                        stamp_dates, selected_station_values,
                        selected_model_values,
                        grid_drainages,
                        grid_lons, grid_lats)


    do_plot_selected_stations = True
    if do_plot_selected_stations:
        plot_selected_stations(selected_stations, use_warpimage=False, plot_ts = False,
                               i_list = i_list, j_list = j_list)