Exemple #1
0
def output_population_details(pops, outdir, gen_num,
                              plot_cfg=None  # type: Optional[PlotConfig]
                              ):
    """Output population details, i.e., the simulation data, etc.

    Files written into `outdir` (N is `gen_num`):

    - genN_allSimData.pickle: list of `ind.sim.data` of every individual.
    - genN_caliObsData.json: list of `ind.cali.sim_obs_data` of every individual,
      each tagged (in place) with 'Gen', 'ID' and 'var_name' keys.
    - genN_caliSimData.pickle: list of `ind.cali.data` of every individual.
    - genN_valiObsData.json / genN_valiSimData.pickle: the same as the calibration
      files but for the validation period; only written when the first individual
      reports `vali.valid`.

    Finally, it tries to calculate and plot the 95PPU via `calculate_95ppu`.
    Plotting is best-effort: a failure is reported but does not abort the caller.

    Args:
        pops: Individuals of the current generation.
        outdir: Output directory.
        gen_num: Generation number, used in the output file names.
        plot_cfg: (Optional) Plot settings for matplotlib.
    """
    # Save as json, which can be loaded by json.load()
    # 1. Save the time series simulation data of the entire simulation period
    all_sim_data = [ind.sim.data for ind in pops]
    pickle_file = outdir + os.path.sep + 'gen%d_allSimData.pickle' % gen_num
    with open(pickle_file, 'wb') as f:
        pickle.dump(all_sim_data, f)

    # 2. Save the matched observation-simulation data of calibration period,
    #      and the simulation data separately.
    cali_sim_obs_data = list()
    for ind in pops:
        # Tag each record so that it can be identified after loading the json file
        ind.cali.sim_obs_data['Gen'] = ind.gen
        ind.cali.sim_obs_data['ID'] = ind.id
        ind.cali.sim_obs_data['var_name'] = ind.cali.vars
        cali_sim_obs_data.append(ind.cali.sim_obs_data)
    json_file = outdir + os.path.sep + 'gen%d_caliObsData.json' % gen_num
    with open(json_file, 'w', encoding='utf-8') as f:
        f.write(json.dumps(cali_sim_obs_data, indent=4, cls=SpecialJsonEncoder))
    cali_sim_data = [ind.cali.data for ind in pops]
    pickle_file = outdir + os.path.sep + 'gen%d_caliSimData.pickle' % gen_num
    with open(pickle_file, 'wb') as f:
        pickle.dump(cali_sim_data, f)

    # 3. Save the matched observation-simulation data of validation period,
    #      and the simulation data separately.
    vali_sim_obs_data = list()
    vali_sim_data = list()
    # Guard against an empty population, which has no first individual to inspect
    if pops and pops[0].vali.valid:
        for ind in pops:
            ind.vali.sim_obs_data['Gen'] = ind.gen
            ind.vali.sim_obs_data['ID'] = ind.id
            ind.vali.sim_obs_data['var_name'] = ind.vali.vars
            vali_sim_obs_data.append(ind.vali.sim_obs_data)
        json_file = outdir + os.path.sep + 'gen%d_valiObsData.json' % gen_num
        with open(json_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(vali_sim_obs_data, indent=4, cls=SpecialJsonEncoder))
        vali_sim_data = [ind.vali.data for ind in pops]
        pickle_file = outdir + os.path.sep + 'gen%d_valiSimData.pickle' % gen_num
        with open(pickle_file, 'wb') as f:
            pickle.dump(vali_sim_data, f)

    # 4. Try to plot.
    if plot_cfg is None:
        plot_cfg = PlotConfig()
    try:
        # Calculate 95PPU for current generation, and plot the desired variables, e.g., Q and SED
        calculate_95ppu(cali_sim_obs_data, cali_sim_data, outdir, gen_num,
                        vali_sim_obs_data, vali_sim_data,
                        plot_cfg=plot_cfg)
    except Exception as err:
        # Best-effort: a plotting failure must not abort the calibration run,
        # but it should not be silently swallowed either.
        print('Warning: Failed to calculate or plot 95PPU of generation %d: %s'
              % (gen_num, str(err)))
Exemple #2
0
def plot_pareto_fronts_multiple(method_paths,  # type: Dict[AnyStr, AnyStr]
                                sce_name,  # type: AnyStr
                                xname,  # type: List[AnyStr, AnyStr, Optional[float], Optional[float]]
                                yname,  # type: List[AnyStr, AnyStr, Optional[float], Optional[float]]
                                gens,  # type: List[int]
                                ws,  # type: AnyStr
                                plot_cfg=None  # type: Optional[PlotConfig]
                                ):
    # type: (...) -> None
    """
    Plot Pareto fronts of different methods at the same generation(s) for comparison.

    The Pareto points of each method are read from `<method path>/runtime.log`.

    Args:
        method_paths(OrderedDict): key is method name (which is also displayed in legend),
                                   value is the directory path containing 'runtime.log'.
        sce_name(str): Scenario ID field name.
        xname(list): the first is x field name in log file, and the second is for plot,
                     the third and fourth values are low and high limits (optional).
        yname(list): see xname
        gens(list): generations to be plotted
        ws(string): workspace for output files
        plot_cfg(PlotConfig): Plot settings for matplotlib
    """
    # Keyed by method name, in the same order as `method_paths`
    pareto_data = OrderedDict()  # type: Dict[AnyStr, Union[List, numpy.ndarray]]
    for method_name, method_dir in viewitems(method_paths):
        log_file = method_dir + os.path.sep + 'runtime.log'
        # The second returned value (accumulated population size) is not needed here
        pareto_data[method_name], _ = read_pareto_points_from_txt(log_file, sce_name, xname)
    if plot_cfg is None:
        plot_cfg = PlotConfig()
    # Strip the log field names; keep plot labels and optional axis limits
    plot_pareto_fronts(pareto_data, xname[1:], yname[1:], gens, ws, plot_cfg=plot_cfg)
Exemple #3
0
def plot_hypervolumes(hyperv,  # type: Dict[AnyStr, Union[int, float, List[List[int], List[float]]]]
                      ws,  # type: AnyStr
                      cn=False,  # type: bool # Deprecated! Please use plot_cfg=PlotConfig instead.
                      plot_cfg=None  # type: Optional[PlotConfig]
                      ):
    # type: (...) -> bool
    """Plot hypervolume of multiple optimization methods

    Args:
        hyperv: Dict, key is method name and value is [generation IDs list, hypervolumes list].
                Optionally, key-values of 'bottom' and 'top' are allowed to set y-axis limits.
        ws: Full path of the destination directory
        cn: (Optional) Use Chinese. Deprecated and not used, set `plot_cfg.plot_cn` instead.
        plot_cfg: (Optional) Plot settings for matplotlib

    Returns:
        True after the 'hypervolume' figure has been saved into `ws`.
    """
    if plot_cfg is None:
        plot_cfg = PlotConfig()
    plt.rcParams['xtick.direction'] = 'out'
    plt.rcParams['ytick.direction'] = 'out'
    plt.rcParams['font.family'] = plot_cfg.font_name
    generation_str = (u'进化代数' if plot_cfg.plot_cn else 'Generation')
    hyperv_str = (u'Hypervolume 指数' if plot_cfg.plot_cn else 'Hypervolume index')

    # Line styles: https://matplotlib.org/gallery/lines_bars_and_markers/line_styles_reference.html
    linestyles = ['-', '--', '-.', ':']
    # plot accumulate pop size
    _, ax = plt.subplots(figsize=(9, 8))
    line_idx = 0
    for method, gen_hyperv in viewitems(hyperv):
        if not isinstance(gen_hyperv, list):
            continue  # e.g., the optional 'bottom'/'top' limits
        # Cycle through the styles so that more than four methods do not raise IndexError
        plt.plot(gen_hyperv[0], gen_hyperv[1],
                 linestyle=linestyles[line_idx % len(linestyles)], color='black',
                 label=method, linewidth=2)
        line_idx += 1
    plt.legend(fontsize=plot_cfg.legend_fsize, loc=4)
    # tick_params applies to ticks re-generated later (e.g., by set_xlim below),
    # unlike setting the font size on the tick label objects existing right now.
    ax.tick_params(axis='both', which='major', labelsize=plot_cfg.tick_fsize)
    plt.xlabel(generation_str, fontsize=plot_cfg.axislabel_fsize)
    plt.ylabel(hyperv_str, fontsize=plot_cfg.axislabel_fsize)
    ax.set_xlim(left=0, right=ax.get_xlim()[1] + 2)
    if 'bottom' in hyperv:
        ax.set_ylim(bottom=hyperv['bottom'])
    if 'top' in hyperv:
        ax.set_ylim(top=hyperv['top'])
    plt.tight_layout()
    save_png_eps(plt, ws, 'hypervolume', plot_cfg)
    # close current plot in case of 'figure.max_open_warning'
    plt.cla()
    plt.clf()
    plt.close()

    return True
Exemple #4
0
    def __init__(self, cf, method='morris'):
        """Initialization.

        Args:
            cf: Parsed *.ini configuration, which must contain the [PSA_Settings] section.
            method: Parameter sensitivity analysis method, 'morris' (default) or 'fast'.

        Raises:
            ValueError: If [PSA_Settings] is missing, the PSA period is missing or
                        invalid, or `method` is not supported.
            IOError: If the parameter ranges definition file does not exist.
        """
        self.method = method
        # 1. SEIMS model related
        self.model = ParseSEIMSConfig(cf)
        # 2. Common settings of parameters sensitivity analysis
        if 'PSA_Settings' not in cf.sections():
            raise ValueError(
                "[PSA_Settings] section MUST be existed in *.ini file.")

        # Variables to be evaluated
        if cf.has_option('PSA_Settings', 'evaluateparam'):
            eva_str = cf.get('PSA_Settings', 'evaluateparam')
            self.evaluate_params = StringClass.split_string(eva_str, ',')
        else:
            self.evaluate_params = ['Q']  # Default

        # Parameter ranges definition file, located in the model directory
        self.param_range_def = 'morris_param_rng.def'  # Default
        if cf.has_option('PSA_Settings', 'paramrngdef'):
            self.param_range_def = cf.get('PSA_Settings', 'paramrngdef')
        self.param_range_def = self.model.model_dir + os.path.sep + self.param_range_def
        if not FileClass.is_file_exists(self.param_range_def):
            raise IOError('Ranges of parameters MUST be provided!')

        if not (cf.has_option('PSA_Settings', 'psa_time_start')
                and cf.has_option('PSA_Settings', 'psa_time_end')):
            raise ValueError(
                "Start and end time of PSA MUST be specified in [PSA_Settings]."
            )
        try:
            # UTCTIME
            tstart = cf.get('PSA_Settings', 'psa_time_start')
            tend = cf.get('PSA_Settings', 'psa_time_end')
            self.psa_stime = StringClass.get_datetime(tstart)
            self.psa_etime = StringClass.get_datetime(tend)
        except ValueError:
            raise ValueError('The time format MUST be "YYYY-MM-DD HH:MM:SS".')
        if self.psa_stime >= self.psa_etime:
            raise ValueError("Wrong time settings in [PSA_Settings]!")

        # 3. Parameters settings for specific sensitivity analysis methods
        self.morris = None
        self.fast = None
        if self.method == 'fast':
            self.fast = FASTConfig(cf)
            self.psa_outpath = '%s/PSA_FAST_N%dM%d' % (
                self.model.model_dir, self.fast.N, self.fast.M)
        elif self.method == 'morris':
            self.morris = MorrisConfig(cf)
            self.psa_outpath = '%s/PSA_Morris_N%dL%d' % (
                self.model.model_dir, self.morris.N, self.morris.num_levels)
        else:
            # Fail with a clear message instead of an AttributeError on
            # the undefined `self.psa_outpath` below.
            raise ValueError("Unsupported PSA method '%s', MUST be 'morris' or 'fast'."
                             % self.method)
        # 4. (Optional) Plot settings for matplotlib
        self.plot_cfg = PlotConfig(cf)

        # Do not remove psa_outpath if already existed
        UtilClass.mkdir(self.psa_outpath)
        self.outfiles = PSAOutputs(self.psa_outpath)
Exemple #5
0
def plot_2d_scatter(xlist,  # type: List[float] # X coordinates
                    ylist,  # type: List[float] # Y coordinates
                    title,  # type: AnyStr # Main title of the figure
                    xlabel,  # type: AnyStr # X-axis label
                    ylabel,  # type: AnyStr # Y-axis label
                    ws,  # type: AnyStr # Full path of the destination directory
                    filename,  # type: AnyStr # File name without suffix (e.g., jpg, eps)
                    subtitle='',  # type: AnyStr # Subtitle
                    cn=False,  # type: bool # Use Chinese or not. Deprecated!
                    xmin=None,  # type: Optional[float] # Left min X value
                    xmax=None,  # type: Optional[float] # Right max X value
                    xstep=None,  # type: Optional[float] # X interval
                    ymin=None,  # type: Optional[float] # Bottom min Y value
                    ymax=None,  # type: Optional[float] # Up max Y value
                    ystep=None,  # type: Optional[float] # Y interval
                    plot_cfg=None  # type: Optional[PlotConfig]
                    ):
    # type: (...) -> None
    """Scatter plot of 2D points.

    The figure is saved into `ws` as `filename` via `save_png_eps`.
    Only the axis limits that are given are overridden. When a step is given,
    major ticks are placed from the current lower limit up to the upper limit at
    that interval. Tick labels always use `plot_cfg.tick_fsize`.
    `cn` is deprecated and not used; set `plot_cfg.plot_cn` instead.

    Todo: The size of the point may be vary with the number of points.
    """
    if plot_cfg is None:
        plot_cfg = PlotConfig()
    plt.rcParams['font.family'] = plot_cfg.font_name
    plt.figure()
    plt.title('%s\n' % title, color='red', fontsize=plot_cfg.title_fsize)
    plt.xlabel(xlabel, fontsize=plot_cfg.axislabel_fsize)
    plt.ylabel(ylabel, fontsize=plot_cfg.axislabel_fsize)
    plt.scatter(xlist, ylist, c='r', alpha=0.8, s=12)
    if xmax is not None:
        plt.xlim(right=xmax)
    if xmin is not None:
        plt.xlim(left=xmin)
    if xstep is not None:
        xlow, xhigh = plt.xlim()
        # arange excludes its stop value; adding 0.99 * step keeps `xhigh` itself
        plt.xticks(numpy.arange(xlow, xhigh + xstep * 0.99, step=xstep))
    if ymax is not None:
        plt.ylim(top=ymax)
    if ymin is not None:
        plt.ylim(bottom=ymin)
    if ystep is not None:
        ylow, yhigh = plt.ylim()
        plt.yticks(numpy.arange(ylow, yhigh + ystep * 0.99, step=ystep))
    # Apply the tick label size whether or not explicit tick steps were given
    plt.tick_params(axis='both', which='major', labelsize=plot_cfg.tick_fsize)

    if subtitle != '':
        plt.title(subtitle, color='green', fontsize=plot_cfg.label_fsize, loc='right')
    plt.tight_layout()
    save_png_eps(plt, ws, filename, plot_cfg)
    # close current plot in case of 'figure.max_open_warning'
    plt.cla()
    plt.clf()
    plt.close()
Exemple #6
0
def plot_hypervolume_multiple(method_paths, ws, cn=False, plot_cfg=None):
    # type: (Dict[AnyStr, AnyStr], AnyStr, bool, Optional[PlotConfig]) -> bool
    """Read and plot the hypervolume evolution of multiple optimization methods.

    For every method, 'hypervolume.txt' in its directory is read by
    `read_hypervolume`. The generation IDs and hypervolume values are then
    forwarded to `plot_hypervolumes`.

    Args:
        method_paths: Dict, key is method name (legend label) and value is the full
                      path of the directory holding 'hypervolume.txt'
        ws: Full path of the destination directory
        cn: (Optional) Use Chinese. Deprecated!
        plot_cfg: (Optional) Plot settings for matplotlib

    Returns:
        Whatever `plot_hypervolumes` returns.
    """
    collected = OrderedDict()  # type: Dict[AnyStr, List[List[int], List[float]]]
    for name, directory in viewitems(method_paths):
        hv_file = directory + os.path.sep + 'hypervolume.txt'
        gen_ids, _, hv_values = read_hypervolume(hv_file)  # middle item: model counts
        collected[name] = [gen_ids[:], hv_values[:]]
    cfg = PlotConfig() if plot_cfg is None else plot_cfg
    return plot_hypervolumes(collected, ws, cn, plot_cfg=cfg)
Exemple #7
0
    def __init__(self, cf, method='nsga2'):
        # type: (ConfigParser, str) -> None
        """Initialization.

        Reads the SEIMS model settings, the [CALI_Settings] section (parameter ranges
        file and calibration/optional validation periods), optimizer settings and
        plot settings.

        Args:
            cf: Parsed *.ini configuration.
            method: Optimization method; NSGA-II settings are only parsed for 'nsga2',
                    otherwise `self.opt` stays None.

        Raises:
            ValueError: If [CALI_Settings] is missing, the calibration period is not
                        specified, or a period's start is not before its end.
            IOError: If the parameter ranges definition file does not exist.
        """
        # 1. SEIMS model related
        self.model = ParseSEIMSConfig(cf)

        # 2. Common settings of auto-calibration
        if 'CALI_Settings' not in cf.sections():
            raise ValueError(
                "[CALI_Settings] section MUST be existed in *.ini file.")
        if cf.has_option('CALI_Settings', 'paramrngdef'):
            rng_def_name = cf.get('CALI_Settings', 'paramrngdef')
        else:
            rng_def_name = 'cali_param_rng.def'
        self.param_range_def = self.model.model_dir + os.path.sep + rng_def_name
        if not FileClass.is_file_exists(self.param_range_def):
            raise IOError('Ranges of parameters MUST be provided!')

        # UTCTIME of calibration and validation (optional) periods
        has_cali_period = cf.has_option('CALI_Settings', 'cali_time_start') \
            and cf.has_option('CALI_Settings', 'cali_time_end')
        if not has_cali_period:
            raise ValueError("Start and end time of Calibration "
                             "MUST be specified in [CALI_Settings].")
        (self.cali_stime, self.cali_etime,
         self.vali_stime, self.vali_etime) = [
            parse_datetime_from_ini(cf, 'CALI_Settings', opt_name)
            for opt_name in ('cali_time_start', 'cali_time_end',
                             'vali_time_start', 'vali_time_end')]
        # Validation is only evaluated when both of its bounds are set
        self.calc_validation = bool(self.vali_stime and self.vali_etime)
        if self.cali_stime >= self.cali_etime:
            raise ValueError("Wrong time settings in [CALI_Settings]!")
        if self.calc_validation and self.vali_stime >= self.vali_etime:
            raise ValueError("Wrong time settings in [CALI_Settings]!")

        # 3. Parameters settings for specific optimization algorithm
        self.opt_mtd = method
        if self.opt_mtd == 'nsga2':
            self.opt = ParseNSGA2Config(cf, self.model.model_dir,
                                        'CALI_NSGA2_Gen_%d_Pop_%d')
        else:
            self.opt = None

        # 4. (Optional) Plot settings for matplotlib
        self.plot_cfg = PlotConfig(cf)
Exemple #8
0
    def __init__(self, cf):
        """Initialization.

        Reads the SEIMS model settings, the subbasin ID and variables to plot from
        [PARAMETERS], the calibration and (optional) validation periods from
        [OPTIONAL_PARAMETERS], and the plot settings.

        Raises:
            ValueError: If a required section is missing, the subbasin ID is negative,
                        the plot variables are empty, or a period is invalid.
        """
        # 1. SEIMS model related
        self.model_cfg = ParseSEIMSConfig(cf)
        # 2. Parameters
        self.plt_subbsnid = -1
        self.plot_vars = list()
        if 'PARAMETERS' not in cf.sections():
            raise ValueError(
                "[PARAMETERS] section MUST be existed in *.ini file.")
        self.plt_subbsnid = cf.getint('PARAMETERS', 'plot_subbasinid')
        variables_text = cf.get('PARAMETERS', 'plot_variables')
        if self.plt_subbsnid < 0:
            raise ValueError(
                "PLOT_SUBBASINID must be greater or equal than 0.")
        if variables_text == '':
            raise ValueError("PLOT_VARIABLES illegal defined in [PARAMETERS]!")
        # Variables may be separated by commas and/or spaces
        self.plot_vars = StringClass.split_string(variables_text, [',', ' '])

        # 3. Optional_Parameters
        if 'OPTIONAL_PARAMETERS' not in cf.sections():
            raise ValueError(
                "[OPTIONAL_PARAMETERS] section MUST be existed in *.ini file.")
        # UTCTIME
        self.cali_stime = parse_datetime_from_ini(cf, 'OPTIONAL_PARAMETERS', 'cali_time_start')
        self.cali_etime = parse_datetime_from_ini(cf, 'OPTIONAL_PARAMETERS', 'cali_time_end')
        self.vali_stime = parse_datetime_from_ini(cf, 'OPTIONAL_PARAMETERS', 'vali_time_start')
        self.vali_etime = parse_datetime_from_ini(cf, 'OPTIONAL_PARAMETERS', 'vali_time_end')

        # The calibration period is mandatory and must be well ordered
        if not (self.cali_stime and self.cali_etime) or self.cali_stime >= self.cali_etime:
            raise ValueError(
                "Wrong time settings of calibration in [OPTIONAL_PARAMETERS]!")
        # The validation period is optional; it is only checked when fully given
        if self.vali_stime and self.vali_etime:
            if self.vali_stime >= self.vali_etime:
                raise ValueError(
                    "Wrong time settings of validation in [OPTIONAL_PARAMETERS]!")
        # 4. Plot settings based on matplotlib
        self.plot_cfg = PlotConfig(cf)
Exemple #9
0
    def __init__(self, cf, method='nsga2'):
        # type: (ConfigParser, str) -> None
        """Initialization.

        Args:
            cf: Parsed *.ini configuration, which must contain the [Scenario_Common]
                and [BMPs] sections.
            method: Optimization method; only 'nsga2' is currently supported.

        Raises:
            ValueError: If a required section is missing, the BMPs configuration
                        unit or unit/method pair is invalid, or `method` is not supported.
        """
        # 1. SEIMS model related
        self.model = ParseSEIMSConfig(cf)  # type: ParseSEIMSConfig

        # 2. Common settings of BMPs scenario
        self.eval_stime = None  # type: Optional[datetime]
        self.eval_etime = None  # type: Optional[datetime]
        self.worst_econ = 0.
        self.worst_env = 0.
        self.runtime_years = 0.
        self.export_sce_txt = False
        self.export_sce_tif = False
        if 'Scenario_Common' not in cf.sections():
            raise ValueError(
                '[Scenario_Common] section MUST be existed in *.ini file.')
        self.eval_stime = parse_datetime_from_ini(cf, 'Scenario_Common',
                                                  'eval_time_start')
        self.eval_etime = parse_datetime_from_ini(cf, 'Scenario_Common',
                                                  'eval_time_end')
        self.worst_econ = cf.getfloat('Scenario_Common', 'worst_economy')
        self.worst_env = cf.getfloat('Scenario_Common', 'worst_environment')
        self.runtime_years = cf.getfloat('Scenario_Common', 'runtime_years')
        if cf.has_option('Scenario_Common', 'export_scenario_txt'):
            self.export_sce_txt = cf.getboolean('Scenario_Common',
                                                'export_scenario_txt')
        if cf.has_option('Scenario_Common', 'export_scenario_tif'):
            self.export_sce_tif = cf.getboolean('Scenario_Common',
                                                'export_scenario_tif')

        # 3. Application specific setting section [BMPs]
        # Selected BMPs, the key is BMPID, and value is the BMP information dict
        self.bmps_info = dict(
        )  # type: Dict[int, Dict[AnyStr, Union[int, float, AnyStr, List[Union[int, float, AnyStr]]]]]
        # BMPs to be constant for generated scenarios during optimization, same format with bmps_info
        self.bmps_retain = dict(
        )  # type: Dict[int, Dict[AnyStr, Union[int, float, AnyStr, List[Union[int, float, AnyStr]]]]]
        self.eval_info = dict(
        )  # type: Dict[AnyStr, Union[int, float, AnyStr]]
        self.bmps_cfg_unit = 'CONNFIELD'  # type: AnyStr
        self.bmps_cfg_method = 'RAND'  # type: AnyStr
        if 'BMPs' not in cf.sections():
            raise ValueError(
                '[BMPs] section MUST be existed for specific scenario analysis.'
            )

        bmpsinfostr = cf.get('BMPs', 'bmps_info')
        self.bmps_info = UtilClass.decode_strs_in_dict(json.loads(bmpsinfostr))
        if cf.has_option('BMPs', 'bmps_retain'):
            bmpsretainstr = cf.get('BMPs', 'bmps_retain')
            self.bmps_retain = json.loads(bmpsretainstr)
            self.bmps_retain = UtilClass.decode_strs_in_dict(self.bmps_retain)
        evalinfostr = cf.get('BMPs', 'eval_info')
        self.eval_info = UtilClass.decode_strs_in_dict(json.loads(evalinfostr))
        bmpscfgunitstr = cf.get('BMPs', 'bmps_cfg_units')
        bmpscfgunitdict = UtilClass.decode_strs_in_dict(
            json.loads(bmpscfgunitstr))
        # Only the first configuration unit is used (see the `break` below)
        for unitname, unitcfg in viewitems(bmpscfgunitdict):
            self.bmps_cfg_unit = unitname
            if self.bmps_cfg_unit not in BMPS_CFG_UNITS:
                raise ValueError('BMPs configuration unit MUST be '
                                 'one of %s' % BMPS_CFG_UNITS.__str__())
            if not isinstance(unitcfg, dict):
                raise ValueError(
                    'The value of BMPs configuration unit MUST be dict value!')
            # Unit-level settings fill in keys not already set per BMP
            for cfgname, cfgvalue in viewitems(unitcfg):
                for bmpid, bmpdict in viewitems(self.bmps_info):
                    if cfgname in bmpdict:
                        continue
                    self.bmps_info[bmpid][cfgname] = cfgvalue
            break

        if cf.has_option('BMPs', 'bmps_cfg_method'):
            self.bmps_cfg_method = cf.get('BMPs', 'bmps_cfg_method')
            if self.bmps_cfg_method not in BMPS_CFG_METHODS:
                print('BMPs configuration method MUST be one of %s' %
                      BMPS_CFG_METHODS.__str__())
                self.bmps_cfg_method = 'RAND'

        # Check the validation of configuration unit and method.
        # Default to an empty list so that an unit missing from BMPS_CFG_PAIR raises
        # the descriptive ValueError below instead of a TypeError on `in None`.
        if self.bmps_cfg_method not in BMPS_CFG_PAIR.get(self.bmps_cfg_unit, list()):
            raise ValueError('BMPs configuration method %s '
                             'is not supported on unit %s' %
                             (self.bmps_cfg_method, self.bmps_cfg_unit))

        # Optimize boundary of BMP configuration unit
        self.boundary_adaptive = False
        self.boundary_adaptive_threshs = None
        if cf.has_option('BMPs', 'bmps_cfg_units_opt'):
            self.boundary_adaptive = cf.getboolean('BMPs',
                                                   'bmps_cfg_units_opt')
        if cf.has_option('BMPs', 'boundary_adaptive_threshold'):
            tstr = cf.get('BMPs', 'boundary_adaptive_threshold')
            self.boundary_adaptive_threshs = StringClass.extract_numeric_values_from_string(
                tstr)
            if 0 not in self.boundary_adaptive_threshs:
                self.boundary_adaptive_threshs.append(
                    0)  # 0 means no adjustment of boundary
            # Make the thresholds symmetric around 0. Iterate over a snapshot because
            # the list is extended inside the loop.
            for tmp_thresh in self.boundary_adaptive_threshs[:]:
                if -1 * tmp_thresh not in self.boundary_adaptive_threshs:
                    self.boundary_adaptive_threshs.append(-1 * tmp_thresh)

        # 4. Parameters settings for specific optimization algorithm
        self.opt_mtd = method
        self.opt = None  # type: Union[ParseNSGA2Config, None]
        if self.opt_mtd == 'nsga2':
            self.opt = ParseNSGA2Config(
                cf, self.model.model_dir,
                'SA_NSGA2_%s_%s' % (self.bmps_cfg_unit, self.bmps_cfg_method))
        else:
            # `self.opt.out_dir` is needed below; fail with a clear message instead
            # of an AttributeError on None.
            raise ValueError("Unsupported optimization method '%s', MUST be 'nsga2'."
                             % self.opt_mtd)
        # Using the existed population derived from previous scenario optimization
        self.initial_byinput = cf.getboolean(self.opt_mtd.upper(), 'inputpopulation') if \
            cf.has_option(self.opt_mtd.upper(), 'inputpopulation') else False
        self.input_pareto_file = None
        self.input_pareto_gen = -1
        if cf.has_option(self.opt_mtd.upper(), 'paretofrontsfile'):
            self.input_pareto_file = cf.get(self.opt_mtd.upper(),
                                            'paretofrontsfile')
        if cf.has_option(self.opt_mtd.upper(), 'generationselected'):
            self.input_pareto_gen = cf.getint(self.opt_mtd.upper(),
                                              'generationselected')

        # NOTE(review): rmmkdir presumably removes and recreates the directory — confirm
        self.scenario_dir = self.opt.out_dir + os.path.sep + 'Scenarios'
        UtilClass.rmmkdir(self.scenario_dir)

        # 5. (Optional) Plot settings for matplotlib
        self.plot_cfg = PlotConfig(cf)
Exemple #10
0
def _percentile_nearest(data, q):
    """Return the q-th percentile of `data` along axis 0, selecting the nearest value.

    NumPy >= 1.22 names this option `method` and deprecates `interpolation`.
    Older NumPy rejects `method` with a TypeError, so fall back to `interpolation` there.
    """
    try:
        return numpy.percentile(data, q, axis=0, method='nearest')
    except TypeError:
        return numpy.percentile(data, q, axis=0, interpolation='nearest')


def _calculate_95ppu_efficiency(obs_data_list, obs_dates_list, sim_dates_list, ylows, yups):
    """Calculate the P-factor and R-factor of a 95PPU band.

    Args:
        obs_data_list: Observed values.
        obs_dates_list: Dates of `obs_data_list`, in the same order.
        sim_dates_list: Dates of the simulated series; position i corresponds to
                        `ylows[i]` and `yups[i]`.
        ylows: Lower (2.5%) bound of the band.
        yups: Upper (97.5%) bound of the band.

    Returns:
        (p, r). p is the fraction of observations that fall inside the band;
        observations whose date is not simulated count as outside. r is the mean
        band width at the matched dates divided by the standard deviation of the
        observations. NaN values are returned when they cannot be computed.
    """
    if len(obs_data_list) == 0:
        return float('nan'), float('nan')
    # Position of each simulated date. The first occurrence wins, as with list.index().
    date_pos = dict()
    for pos, sim_date in enumerate(sim_dates_list):
        date_pos.setdefault(sim_date, pos)
    count = 0
    ylows_obs = list()
    yups_obs = list()
    for obs_date, obs_value in zip(obs_dates_list, obs_data_list):
        si = date_pos.get(obs_date)
        if si is None:
            continue
        ylows_obs.append(ylows[si])
        yups_obs.append(yups[si])
        if ylows[si] <= obs_value <= yups[si]:
            count += 1
    p = float(count) / len(obs_data_list)
    if not ylows_obs:
        return p, float('nan')
    r = numpy.mean(numpy.array(yups_obs) - numpy.array(ylows_obs)) / \
        numpy.std(numpy.array(obs_data_list))
    return p, r


def calculate_95ppu(sim_obs_data, sim_data, outdir, gen_num,
                    vali_sim_obs_data=None, vali_sim_data=None,
                    plot_cfg=None  # type: Optional[PlotConfig]
                    ):
    """Calculate 95% prediction uncertainty (95PPU) and plot the hydrographs.

    One figure named 'Gen<gen_num>_95ppu_<var>' is saved into `outdir` for each
    variable in ``sim_obs_data[0]['var_name']``. Nothing is plotted when fewer
    than two simulations are given.

    Args:
        sim_obs_data: Per-individual dicts of matched observation-simulation data of
                      the calibration period; ``[var]`` holds 'UTCDATETIME', 'Obs',
                      'NSE', 'RSR', 'PBIAS' and 'R-square'.
        sim_data: Per-individual simulated data of the calibration period. Each is a
                  dict mapping date to a row of values ordered like 'var_name'.
        outdir: Output directory.
        gen_num: Generation number, used in the figure names.
        vali_sim_obs_data: (Optional) Same as `sim_obs_data` for the validation period.
        vali_sim_data: (Optional) Same as `sim_data` for the validation period.
        plot_cfg: (Optional) Plot settings for matplotlib.
    """
    if plot_cfg is None:
        plot_cfg = PlotConfig()
    plt.rcParams['xtick.direction'] = 'out'
    plt.rcParams['ytick.direction'] = 'out'
    plt.rcParams['font.family'] = plot_cfg.font_name
    plt.rcParams['timezone'] = 'UTC'
    plt.rcParams['mathtext.fontset'] = 'custom'
    plt.rcParams['mathtext.it'] = 'STIXGeneral:italic'
    plt.rcParams['mathtext.bf'] = 'STIXGeneral:italic:bold'
    if len(sim_data) < 2:
        return
    var_name = sim_obs_data[0]['var_name']
    for idx, var in enumerate(var_name):
        plot_validation = False
        if vali_sim_obs_data and vali_sim_data and var in vali_sim_obs_data[0]['var_name']:
            plot_validation = True
        ylabel_str = var
        if var in ['Q', 'QI', 'QG', 'QS']:
            ylabel_str += ' (m$^3$/s)'
        elif 'CONC' in var.upper():  # Concentrate
            if 'SED' in var.upper():
                ylabel_str += ' (g/L)'
            else:
                ylabel_str += ' (mg/L)'
        elif 'SED' in var.upper():  # amount
            ylabel_str += ' (kg)'
        cali_obs_dates = sim_obs_data[0][var]['UTCDATETIME'][:]
        if is_string(cali_obs_dates[0]):
            cali_obs_dates = [StringClass.get_datetime(s) for s in cali_obs_dates]
        obs_dates = cali_obs_dates[:]
        order = 1  # By default, the calibration period is before the validation period.
        if plot_validation:
            vali_obs_dates = vali_sim_obs_data[0][var]['UTCDATETIME']
            if is_string(vali_obs_dates[0]):
                vali_obs_dates = [StringClass.get_datetime(s) for s in vali_obs_dates]
            if vali_obs_dates[-1] <= cali_obs_dates[0]:
                order = 0
                obs_dates = vali_obs_dates + obs_dates
            else:
                obs_dates += vali_obs_dates
        obs_data = sim_obs_data[0][var]['Obs'][:]
        if plot_validation:
            if order:
                obs_data += vali_sim_obs_data[0][var]['Obs'][:]
            else:
                obs_data = vali_sim_obs_data[0][var]['Obs'][:] + obs_data

        cali_sim_dates = list(sim_data[0].keys())
        if is_string(cali_sim_dates[0]):
            cali_sim_dates = [StringClass.get_datetime(s) for s in cali_sim_dates]
        sim_dates = cali_sim_dates[:]
        if plot_validation:
            vali_sim_dates = list(vali_sim_data[0].keys())
            if is_string(vali_sim_dates[0]):
                vali_sim_dates = [StringClass.get_datetime(s) for s in vali_sim_dates]
            if order:
                sim_dates += vali_sim_dates
            else:
                sim_dates = vali_sim_dates + sim_dates
        sim_data_list = list()
        caliBestIdx = -1
        caliBestNSE = -9999.
        for idx2, ind in enumerate(sim_data):
            tmp = numpy.array(list(ind.values()))
            tmp = tmp[:, idx]
            if sim_obs_data[idx2][var]['NSE'] > caliBestNSE:
                caliBestNSE = sim_obs_data[idx2][var]['NSE']
                caliBestIdx = idx2
            tmpsim = tmp.tolist()
            if plot_validation:
                tmp_data = numpy.array(list(vali_sim_data[idx2].values()))[:, idx].tolist()
                if order:
                    tmpsim += tmp_data
                else:
                    tmpsim = tmp_data + tmpsim
            sim_data_list.append(tmpsim)

        sim_best = numpy.array(list(sim_data[caliBestIdx].values()))[:, idx]
        sim_best = sim_best.tolist()
        if plot_validation:
            tmp_data = numpy.array(list(vali_sim_data[caliBestIdx].values()))[:, idx].tolist()
            if order:
                sim_best += tmp_data
            else:
                sim_best = tmp_data + sim_best
        sim_data_list = numpy.array(sim_data_list)
        # ylows/yups are aligned with `sim_dates`, the combined (and ordered) period
        ylows = _percentile_nearest(sim_data_list, 2.5)
        yups = _percentile_nearest(sim_data_list, 97.5)

        # concatenate text
        # Dates are looked up in the combined `sim_dates` so that the indices match
        # ylows/yups for both periods, whichever comes first.
        p_value, r_value = _calculate_95ppu_efficiency(sim_obs_data[0][var]['Obs'],
                                                       cali_obs_dates, sim_dates,
                                                       ylows, yups)
        txt = 'P-factor: %.2f\nR-factor: %.2f\n' % (p_value, r_value)
        txt += u'某一最优模拟\n' if plot_cfg.plot_cn else 'One best simulation:\n'
        txt += '    $\\mathit{NSE}$: %.2f\n' \
               '    $\\mathit{RSR}$: %.2f\n' \
               '    $\\mathit{PBIAS}$: %.2f%%\n' \
               '    $\\mathit{R^2}$: %.2f' % (sim_obs_data[caliBestIdx][var]['NSE'],
                                              sim_obs_data[caliBestIdx][var]['RSR'],
                                              sim_obs_data[caliBestIdx][var]['PBIAS'],
                                              sim_obs_data[caliBestIdx][var]['R-square'])
        # concatenate text of validation if needed
        vali_txt = ''
        if plot_validation:
            p_value, r_value = _calculate_95ppu_efficiency(vali_sim_obs_data[0][var]['Obs'],
                                                           vali_obs_dates, sim_dates,
                                                           ylows, yups)
            vali_txt = 'P-factor: %.2f\nR-factor: %.2f\n\n' % (p_value, r_value)
            vali_txt += '    $\\mathit{NSE}$: %.2f\n' \
                        '    $\\mathit{RSR}$: %.2f\n' \
                        '    $\\mathit{PBIAS}$: %.2f%%\n' \
                        '    $\\mathit{R^2}$: %.2f' % (vali_sim_obs_data[caliBestIdx][var]['NSE'],
                                                       vali_sim_obs_data[caliBestIdx][var]['RSR'],
                                                       vali_sim_obs_data[caliBestIdx][var]['PBIAS'],
                                                       vali_sim_obs_data[caliBestIdx][var]['R-square'])
        # plot
        _, ax = plt.subplots(figsize=(12, 4))
        ax.fill_between(sim_dates, ylows.tolist(), yups.tolist(),
                        color=(0.8, 0.8, 0.8), label='95PPU')
        observed_label = u'实测值' if plot_cfg.plot_cn else 'Observed points'
        ax.scatter(obs_dates, obs_data, marker='.', s=20,
                   color='g', label=observed_label)
        besesim_label = u'最优模拟' if plot_cfg.plot_cn else 'Best simulation'
        ax.plot(sim_dates, sim_best, linestyle='--', color='red', linewidth=1,
                label=besesim_label)
        ax.set_xlim(left=min(sim_dates), right=max(sim_dates))
        ax.set_ylim(bottom=0.)
        date_fmt = mdates.DateFormatter('%m-%d-%y')
        ax.xaxis.set_major_formatter(date_fmt)
        ax.tick_params(axis='x', bottom=True, top=False, length=5, width=2, which='major',
                       labelsize=plot_cfg.tick_fsize)
        ax.tick_params(axis='y', left=True, right=False, length=5, width=2, which='major',
                       labelsize=plot_cfg.tick_fsize)
        plt.xlabel(u'时间' if plot_cfg.plot_cn else 'Date time',
                   fontsize=plot_cfg.axislabel_fsize)
        plt.ylabel(ylabel_str, fontsize=plot_cfg.axislabel_fsize)
        # plot separate dash line
        delta_dt = (sim_dates[-1] - sim_dates[0]) // 9
        delta_dt2 = (sim_dates[-1] - sim_dates[0]) // 35
        sep_time = sim_dates[-1]
        time_pos = [sep_time - delta_dt]
        time_pos2 = [sep_time - 2 * delta_dt]
        ybottom, ytop = ax.get_ylim()
        yc = abs(ytop - ybottom) * 0.9
        if plot_validation:
            sep_time = vali_sim_dates[0] if vali_sim_dates[0] >= cali_sim_dates[-1] \
                else cali_sim_dates[0]
            # 率定期: calibration period; 验证期: validation period
            cali_vali_labels = [(u'率定期' if plot_cfg.plot_cn else 'Calibration'),
                                (u'验证期' if plot_cfg.plot_cn else 'Validation')]
            if not order:
                cali_vali_labels.reverse()
            time_pos = [sep_time - delta_dt, sep_time + delta_dt2]
            time_pos2 = [sep_time - 2 * delta_dt, sep_time + delta_dt2]
            ax.axvline(sep_time, color='black', linestyle='dashed', linewidth=2)
            plt.text(time_pos[0], yc, cali_vali_labels[0],
                     fontdict={'style': 'italic', 'weight': 'bold',
                               'size': plot_cfg.label_fsize},
                     color='black')
            plt.text(time_pos[1], yc, cali_vali_labels[1],
                     fontdict={'style': 'italic', 'weight': 'bold',
                               'size': plot_cfg.label_fsize},
                     color='black')

        # add legend in a fixed order
        handles, labels = ax.get_legend_handles_labels()
        figorders = [labels.index('95PPU'), labels.index(observed_label),
                     labels.index(besesim_label)]
        # Use a distinct name so the outer loop variable `idx` is never clobbered
        ax.legend([handles[i] for i in figorders], [labels[i] for i in figorders],
                  fontsize=plot_cfg.legend_fsize, loc=2, framealpha=0.8)
        # add text
        cali_pos = time_pos[0] if order else time_pos[1]
        plt.text(cali_pos, yc * 0.5, txt, color='red', fontsize=plot_cfg.label_fsize - 1)
        if plot_validation:
            vali_pos = time_pos[1] if order else time_pos[0]
            plt.text(vali_pos, yc * 0.5, vali_txt, color='red', fontsize=plot_cfg.label_fsize - 1)
        plt.tight_layout()
        save_png_eps(plt, outdir, 'Gen%d_95ppu_%s' % (gen_num, var), plot_cfg)
        # close current plot in case of 'figure.max_open_warning'
        plt.cla()
        plt.clf()
        plt.close()
Exemple #11
0
def plot_3d_scatter(xlist,  # type: List[float] # X coordinates
                    ylist,  # type: List[float] # Y coordinates
                    zlist,  # type: List[float] # Z coordinates
                    title,  # type: AnyStr # Main title of the figure
                    xlabel,  # type: AnyStr # X-axis label
                    ylabel,  # type: AnyStr # Y-axis label
                    zlabel,  # type: AnyStr # Z-axis label
                    ws,  # type: AnyStr # Full path of the destination directory
                    filename,  # type: AnyStr # File name without suffix (e.g., jpg, eps)
                    subtitle='',  # type: AnyStr # Subtitle
                    cn=False,  # type: bool # Use Chinese or not
                    xmin=None,  # type: Optional[float] # Left min X value
                    xmax=None,  # type: Optional[float] # Right max X value
                    ymin=None,  # type: Optional[float] # Bottom min Y value
                    ymax=None,  # type: Optional[float] # Up max Y value
                    zmin=None,  # type: Optional[float] # Min Z value
                    zmax=None,  # type: Optional[float] # Max Z value
                    xstep=None,  # type: Optional[float] # X interval
                    ystep=None,  # type: Optional[float] # Y interval
                    zstep=None,  # type: Optional[float] # Z interval
                    plot_cfg=None  # type: Optional[PlotConfig]
                    ):
    # type: (...) -> None
    """Draw a 3D scatter plot of the given points and save it under `ws`.

    Only the axis bounds that are given are applied. When a step is given for an
    axis, its major ticks run from the current lower limit up to the upper limit
    (included when it lies on the step grid). The figure is closed after saving.

    Args:
        xlist, ylist, zlist: Coordinates of the points.
        title: Main title, drawn as the figure's suptitle.
        xlabel, ylabel, zlabel: Axis labels.
        ws: Destination directory passed to save_png_eps.
        filename: Output file name without suffix.
        subtitle: Right-aligned subtitle; not drawn when it equals ''.
        cn: Not referenced here; the language follows plot_cfg instead.
        xmin, xmax, ymin, ymax, zmin, zmax: Optional axis bounds.
        xstep, ystep, zstep: Optional major tick intervals.
        plot_cfg: Plot settings; a default PlotConfig() is used when None.
    """
    cfg = plot_cfg if plot_cfg is not None else PlotConfig()
    plt.rcParams['font.family'] = cfg.font_name

    axes3d = plt.figure().add_subplot(111, projection='3d')
    plt.suptitle('%s\n' % title, color='red', fontsize=cfg.title_fsize)
    axes3d.set_xlabel(xlabel, fontsize=cfg.axislabel_fsize)
    axes3d.set_ylabel(ylabel, fontsize=cfg.axislabel_fsize)
    axes3d.set_zlabel(zlabel, fontsize=cfg.axislabel_fsize)
    axes3d.scatter(xlist, ylist, zlist, c='r', s=12)
    for axis in (axes3d.xaxis, axes3d.yaxis, axes3d.zaxis):
        for tick_label in axis.get_ticklabels():
            tick_label.set_fontsize(cfg.tick_fsize)

    def tick_positions(limits, interval):
        # Ticks from the lower limit up to the upper one (kept if on the grid).
        lower, upper = limits
        return numpy.arange(lower, upper + interval * 0.99, step=interval)

    # Each set_*lim call below changes a single bound and keeps the other one.
    if xmax is not None:
        axes3d.set_xlim(right=xmax)
    if xmin is not None:
        axes3d.set_xlim(left=xmin)
    if xstep is not None:
        axes3d.set_xticks(tick_positions(axes3d.get_xlim(), xstep))

    if ymax is not None:
        axes3d.set_ylim(top=ymax)
    if ymin is not None:
        axes3d.set_ylim(bottom=ymin)
    if ystep is not None:
        axes3d.set_yticks(tick_positions(axes3d.get_ylim(), ystep))

    if zmax is not None:
        axes3d.set_zlim3d(top=zmax)
    if zmin is not None:
        axes3d.set_zlim3d(bottom=zmin)
    if zstep is not None:
        axes3d.set_zticks(tick_positions(axes3d.get_zlim(), zstep))

    if subtitle != '':
        plt.title(subtitle, color='green', fontsize=cfg.label_fsize, loc='right')
    plt.tight_layout()
    save_png_eps(plt, ws, filename, cfg)
    # Release the figure to avoid matplotlib's 'figure.max_open_warning'
    plt.cla()
    plt.clf()
    plt.close()
Exemple #12
0
def plot_hypervolume_single(hypervlog, ws=None, cn=False, plot_cfg=None):
    # type: (AnyStr, Optional[AnyStr], bool, Optional[PlotConfig]) -> bool
    """Plot the hypervolume (and, if logged, new model runs) of each generation.

    Two figures are written through save_png_eps:
      * 'hypervolume': the hypervolume curve alone;
      * 'hypervolume_modelruns': the same axes plus a legend and, when the log
        contains new-model-run counts, a second curve on a right-hand Y axis.

    Args:
        hypervlog: Full path of the hypervolume log, parsed by read_hypervolume.
        ws: (Optional) Destination directory; defaults to the log's directory.
        cn: (Optional) Deprecated and not referenced; use plot_cfg.plot_cn.
        plot_cfg: (Optional) Plot settings for matplotlib.

    Returns:
        False if no generation or hypervolume data was read, True otherwise.
    """
    gens, hyperv_values, new_runs = read_hypervolume(hypervlog)
    if not gens or not hyperv_values:
        print('Error: No available hypervolume data loaded!')
        return False
    cfg = plot_cfg if plot_cfg is not None else PlotConfig()

    plt.rcParams['xtick.direction'] = 'out'
    plt.rcParams['ytick.direction'] = 'out'
    plt.rcParams['font.family'] = cfg.font_name
    if cfg.plot_cn:
        gen_label, hyperv_label, runs_label = (u'进化代数', u'Hypervolume 指数',
                                               u'新运行模型次数')
    else:
        gen_label, hyperv_label, runs_label = ('Generation', 'Hypervolume index',
                                               'New model evaluations')

    line_style = '-'
    marker_cycle = ['o', 's', 'v', '*']
    _, ax = plt.subplots(figsize=(10, 6))
    curves = ax.plot(gens, hyperv_values, linestyle=line_style, marker=marker_cycle[0],
                     color='black', label=hyperv_label, linewidth=2, markersize=4)
    plt.xlabel(gen_label, fontsize=cfg.axislabel_fsize)
    plt.ylabel(hyperv_label, fontsize=cfg.axislabel_fsize)
    # Start the X axis at generation 0 while keeping the automatic upper limit
    ax.set_xlim(left=0, right=ax.get_xlim()[1])

    plt.tight_layout()
    out_dir = os.path.dirname(hypervlog) if ws is None else ws
    save_png_eps(plt, out_dir, 'hypervolume', cfg)

    if new_runs:
        # Second curve shares the X axis but has its own Y axis on the right
        ax_right = ax.twinx()
        ax.tick_params(axis='x', which='both', bottom=True, top=False,
                       labelsize=cfg.tick_fsize)
        ax_right.tick_params(axis='y', length=5, width=2, which='major',
                             labelsize=cfg.tick_fsize)
        ax_right.set_ylabel(runs_label, fontsize=cfg.axislabel_fsize)
        curves = curves + ax_right.plot(gens, new_runs, linestyle=line_style,
                                        marker=marker_cycle[1], color='black',
                                        label=runs_label, linewidth=2, markersize=4)

    ax.legend(curves, [curve.get_label() for curve in curves],
              fontsize=cfg.legend_fsize, loc='center right')

    plt.tight_layout()
    save_png_eps(plt, out_dir, 'hypervolume_modelruns', cfg)
    # Release the figure to avoid matplotlib's 'figure.max_open_warning'
    plt.cla()
    plt.clf()
    plt.close()

    return True
Exemple #13
0
def plot_pareto_fronts(pareto_data,
                       # type: Dict[AnyStr, Dict[Union[AnyStr, int], Union[List[List[float]], numpy.ndarray]]]
                       xname,  # type: List[AnyStr, Optional[float], Optional[float]]
                       yname,  # type: List[AnyStr, Optional[float], Optional[float]]
                       gens,  # type: List[Union[AnyStr, int]]
                       ws,  # type: AnyStr
                       plot_cfg=None  # type: Optional[PlotConfig]
                       ):
    # type: (...) -> None
    """
    Plot Pareto fronts of different methods at the same generation for comparison.

    For each generation in `gens`, one scatter figure is saved into `ws` under the
    name 'gen<generation>'. A generation is skipped (nothing is saved) unless every
    method in `pareto_data` contains it.

    Args:
        pareto_data(OrderedDict): method name -> dict mapping a generation to an
            array-like whose first two columns are the x and y values. A method's
            dict may also hold 'min' and 'max' entries ([x, y]) that are used to
            derive axis bounds when xname/yname do not provide them.
        xname(list): the first is x-axis name of plot, the second and third values
            are low and high limits (optional), the fourth is the tick interval
            (optional).
        yname(list): see xname
        gens(list): generations to be plotted (int or str keys of pareto_data)
        ws(string): workspace for output files
        plot_cfg(PlotConfig): Plot settings for matplotlib
    """
    if plot_cfg is None:
        plot_cfg = PlotConfig()
    if len(xname) < 1 or len(yname) < 1:
        xname = ['x-axis']
        yname = ['y-axis']
    xlabel_str = xname[0]
    ylabel_str = yname[0]

    plt.rcParams['xtick.direction'] = 'out'
    plt.rcParams['ytick.direction'] = 'out'
    plt.rcParams['font.family'] = plot_cfg.font_name

    # Marker/color per method. Indices wrap around when there are more methods
    # than styles, instead of raising IndexError.
    markers = ['.', '*', '+', 'x', 'd', 'h', 's', '<', '>']
    colors = ['r', 'b', 'g', 'c', 'm', 'y', 'k', 'k', 'k']

    # Overall min. and max. values over all methods that provide them
    max_x = min_x = max_y = min_y = None
    for cur_pareto_data in pareto_data.values():
        if 'min' in cur_pareto_data:
            cur_min = cur_pareto_data['min']
            if min_x is None or min_x > cur_min[0]:
                min_x = cur_min[0]
            if min_y is None or min_y > cur_min[1]:
                min_y = cur_min[1]
        if 'max' in cur_pareto_data:
            cur_max = cur_pareto_data['max']
            if max_x is None or max_x < cur_max[0]:
                max_x = cur_max[0]
            if max_y is None or max_y < cur_max[1]:
                max_y = cur_max[1]
    newxname = xname[:]
    newyname = yname[:]
    # Derive axis bounds only when the caller gave nothing but the axis name
    if min_x is not None and max_x is not None and len(newxname) < 2:
        newxname += get_optimal_bounds(min_x, max_x)
    if min_y is not None and max_y is not None and len(newyname) < 2:
        newyname += get_optimal_bounds(min_y, max_y)

    for gen in gens:
        # Skip generations that are missing for any method, before creating a figure
        if any(gen not in gen_data for gen_data in pareto_data.values()):
            continue
        _, ax = plt.subplots(figsize=(9, 8))
        # Draw methods in reverse order so the first method ends up on top; the
        # legend is reversed back to the original method order below.
        indexed_methods = list(enumerate(pareto_data.items()))
        for style_idx, (method, gen_data) in reversed(indexed_methods):
            points = numpy.array(gen_data[gen])
            plt.scatter(points[:, 0], points[:, 1],
                        marker=markers[style_idx % len(markers)], s=100,
                        color=colors[style_idx % len(colors)], label=method)

        for xticklabel in ax.xaxis.get_ticklabels():
            xticklabel.set_fontsize(plot_cfg.tick_fsize)
        for yticklabel in ax.yaxis.get_ticklabels():
            yticklabel.set_fontsize(plot_cfg.tick_fsize)
        plt.xlabel(xlabel_str, fontsize=plot_cfg.axislabel_fsize)
        plt.ylabel(ylabel_str, fontsize=plot_cfg.axislabel_fsize)
        # set xy axis limit (and tick interval, if given)
        if len(newxname) >= 3:
            ax.set_xlim(left=newxname[1])
            ax.set_xlim(right=newxname[2])
            if len(newxname) >= 4:
                ax.xaxis.set_ticks(numpy.arange(newxname[1], newxname[2] + newxname[3], newxname[3]))
        if len(newyname) >= 3:
            ax.set_ylim(bottom=newyname[1])
            ax.set_ylim(top=newyname[2])
            if len(newyname) >= 4:
                ax.yaxis.set_ticks(numpy.arange(newyname[1], newyname[2] + newyname[3], newyname[3]))

        handles, labels = ax.get_legend_handles_labels()
        handles.reverse()
        labels.reverse()
        plt.legend(handles, labels, fontsize=plot_cfg.legend_fsize, loc=4)
        # loc 2: upper left, 4: lower right
        plt.tight_layout()
        # '%s' instead of '%d' so str generation keys work too; for int
        # generations the file name is unchanged.
        save_png_eps(plt, ws, 'gen%s' % gen, plot_cfg)

        # close current plot in case of 'figure.max_open_warning'
        plt.cla()
        plt.clf()
        plt.close()
Exemple #14
0
def plot_pareto_fronts_multigenerations(data,
                                        # type: Dict[Union[AnyStr, int], Union[List[List[float]], numpy.ndarray]]
                                        labels,
                                        # type: List[AnyStr] # Labels (axis names) with length of ncols
                                        ws,  # type: AnyStr # Full path of the destination directory
                                        gen_ids,  # type: List[int] # Selected generation IDs
                                        title,  # type: AnyStr # Main title of the figure
                                        lowers=None,
                                        # type: Optional[numpy.ndarray, List[float]] # Lower values of each axis
                                        uppers=None,
                                        # type: Optional[numpy.ndarray, List[float]] # Higher values of each axis
                                        steps=None,
                                        # type: Optional[numpy.ndarray, List[float]] # Intervals of each axis
                                        cn=False,  # type: bool # Use Chinese or not. Deprecated!
                                        plot_cfg=None  # type: Optional[PlotConfig] # Plot settings for matplotlib
                                        ):
    # type: (...) -> None
    """Plot the Pareto fronts of selected generations in one 2D scatter figure.

    Generations in `gen_ids` that are not keys of `data` are skipped. Point colors
    are sampled from the 'gist_heat' colormap at 0.8 * (n - idx) / n, where idx is
    the position in `gen_ids` and n its length. The figure is saved into `ws` as
    'Pareto_Generations_<ids>' (with '_cn' appended when plot_cfg.plot_cn is set).

    Args:
        data: generation -> array-like whose first two columns are x and y values.
        labels: Axis names; labels[0] is the x label and labels[1] the y label.
        ws: Destination directory passed to save_png_eps.
        gen_ids: Generations to plot.
        title: Not referenced in this function.
        lowers, uppers: (Optional) [x, y] axis lower / upper limits.
        steps: (Optional) [x, y] major tick intervals.
        cn: Deprecated and not referenced; use plot_cfg.plot_cn.
        plot_cfg: (Optional) Plot settings for matplotlib.
    """
    filename = 'Pareto_Generations_%s' % ('-'.join(repr(i) for i in gen_ids))
    if plot_cfg is None:
        plot_cfg = PlotConfig()
    plt.rcParams['font.family'] = plot_cfg.font_name
    if plot_cfg.plot_cn:
        filename += '_cn'

    _, ax = plt.subplots(figsize=(9, 8))
    # ColorMaps: https://matplotlib.org/tutorials/colors/colormaps.html
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was deprecated in
    # Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap('gist_heat')  # one of the sequential colormaps
    ngens = len(gen_ids)
    for idx, gen in enumerate(gen_ids):
        if gen not in data:
            continue
        points = numpy.array(data[gen])
        plt.scatter(points[:, 0], points[:, 1], marker='.', s=100,
                    color=cmap(0.8 * (ngens - idx) / ngens),
                    label=(u'第 %d 代' if plot_cfg.plot_cn else 'Generation %d') % gen)
    for xticklabel in ax.xaxis.get_ticklabels():
        xticklabel.set_fontsize(plot_cfg.tick_fsize)
    for yticklabel in ax.yaxis.get_ticklabels():
        yticklabel.set_fontsize(plot_cfg.tick_fsize)
    plt.xlabel(labels[0], fontsize=plot_cfg.axislabel_fsize)
    plt.ylabel(labels[1], fontsize=plot_cfg.axislabel_fsize)
    # set xy axis limit
    if lowers is not None:
        ax.set_xlim(left=lowers[0])
        ax.set_ylim(bottom=lowers[1])
    if uppers is not None:
        ax.set_xlim(right=uppers[0])
        ax.set_ylim(top=uppers[1])
    if steps is not None:
        # Ticks from the current lower limit up to the upper one (kept if on the grid)
        xmin, xmax = plt.xlim()
        plt.xticks(numpy.arange(xmin, xmax + steps[0] * 0.99, step=steps[0]))
        ymin, ymax = plt.ylim()
        plt.yticks(numpy.arange(ymin, ymax + steps[1] * 0.99, step=steps[1]))

    plt.legend(fontsize=plot_cfg.legend_fsize, loc=2)  # loc 2: upper left, 4: lower right, 0: best
    plt.tight_layout()
    save_png_eps(plt, ws, filename, plot_cfg)

    # close current plot in case of 'figure.max_open_warning'
    plt.cla()
    plt.clf()
    plt.close()
Exemple #15
0
def plot_pareto_front_single(data,
                             # type: Union[numpy.ndarray, List[List[float]]] # [nrows * ncols] array
                             labels,
                             # type: List[AnyStr] # Labels (axis names) with length of ncols
                             ws,  # type: AnyStr # Full path of the destination directory
                             gen_id,  # type: Union[int, AnyStr] # Generation ID
                             title,  # type: AnyStr # Main title of the figure
                             lowers=None,
                             # type: Optional[numpy.ndarray, List[float]] # Lower values of each axis
                             uppers=None,
                             # type: Optional[numpy.ndarray, List[float]] # Higher values of each axis
                             steps=None,
                             # type: Optional[numpy.ndarray, List[float]] # Intervals of each axis
                             cn=False,  # type: bool # Use Chinese or not. Deprecated!
                             plot_cfg=None  # type: Optional[PlotConfig]
                             ):
    # type: (...) -> bool
    """
    Plot 2D or 3D pareto front graphs.

    A 2D scatter plot is drawn for every pair of columns of `data`, and a 3D
    scatter plot for every triple of columns when there are at least three. Each
    plot is saved into a sub-directory of `ws` named after the plotted axes
    (e.g. 'Pareto_f1-f2'), which is created if it does not exist.

    Args:
        data: 2-dimension array, nrows * ncols
        labels: Labels (axis names) list, the length should be equal to ncols
        ws: Workspace path
        gen_id: Generation ID
        title: Title
        lowers: (Optional) Lower values of each axis. Default is None.
        uppers: (Optional) Upper values of each axis. Default is None.
        steps: (Optional) Major ticks of each axis. Default is None.
        cn: (Optional) Use Chinese. Deprecated. Please use plot_cfg=PlotConfig instead.
        plot_cfg: (Optional) Plot settings for matplotlib.

    Returns:
        False (after printing an error) if `data` is not 2-dimensional, has fewer
        than two columns, or `labels` does not match the number of columns;
        True otherwise. lowers/uppers/steps of the wrong length are ignored with
        a warning.
    """
    if not isinstance(data, numpy.ndarray):
        data = numpy.array(data)
    # Guard before unpacking data.shape: a 1-D (or empty) input would otherwise
    # raise an unhelpful ValueError instead of reporting the problem.
    if data.ndim != 2:
        print('Error: The fitness values MUST be a 2-dimension array (nrows * ncols)!')
        return False
    pop_size, axis_size = data.shape
    if axis_size <= 1:
        print('Error: The size of fitness values MUST >= 2 to plot 2D graphs!')
        return False
    if len(labels) != axis_size:
        print('Error: The size of fitness values and labels are not consistent!')
        return False
    if lowers is not None and len(lowers) != axis_size:
        print('Warning: The size of fitness values and lowers are not consistent!')
        lowers = None
    if uppers is not None and len(uppers) != axis_size:
        print('Warning: The size of fitness values and uppers are not consistent!')
        uppers = None
    if steps is not None and len(steps) != axis_size:
        print('Warning: The size of fitness values and steps are not consistent!')
        steps = None
    if plot_cfg is None:
        plot_cfg = PlotConfig()
    cn = plot_cfg.plot_cn
    if isinstance(gen_id, int):
        subtitle = '\nGeneration: %d, Population: %d' % (gen_id, pop_size)
        if cn:
            subtitle = u'\n代数: %d, 个体数: %d' % (gen_id, pop_size)
    else:
        subtitle = '\nAll generations, Population: %d' % pop_size
        if cn:
            subtitle = u'\n所有进化代数, 个体数: %d' % pop_size
    # The same file name is used in every axes sub-directory
    filename = 'Pareto_Gen_%s_Pop_%d' % (str(gen_id), pop_size)
    if cn:
        filename += '_cn'

    def _at(values, idx):
        # Per-axis value from lowers/uppers/steps, or None when not provided.
        return None if values is None else values[idx]

    def _main_title(indices):
        # e.g. 'Title (f1, f2)'
        return '%s (%s)' % (title, ', '.join('%s' % labels[i] for i in indices))

    def _axes_workspace(indices):
        # Create (if needed) and return the sub-directory for these axes.
        dirname = 'Pareto_%s' % '-'.join('%s' % labels[i] for i in indices)
        tmpws = ws + os.sep + dirname
        if not os.path.exists(tmpws):
            os.mkdir(tmpws)
        return tmpws

    # 2D plot
    for x_idx, y_idx in itertools.combinations(range(axis_size), 2):
        maintitle = _main_title((x_idx, y_idx))
        tmpws = _axes_workspace((x_idx, y_idx))
        plot_2d_scatter(data[:, x_idx], data[:, y_idx], maintitle,
                        labels[x_idx], labels[y_idx], tmpws, filename, subtitle,
                        xmin=_at(lowers, x_idx), xmax=_at(uppers, x_idx),
                        ymin=_at(lowers, y_idx), ymax=_at(uppers, y_idx),
                        xstep=_at(steps, x_idx), ystep=_at(steps, y_idx),
                        plot_cfg=plot_cfg)
    if axis_size >= 3:
        # 3D plot
        for x_idx, y_idx, z_idx in itertools.combinations(range(axis_size), 3):
            maintitle = _main_title((x_idx, y_idx, z_idx))
            tmpws = _axes_workspace((x_idx, y_idx, z_idx))
            plot_3d_scatter(data[:, x_idx], data[:, y_idx], data[:, z_idx], maintitle,
                            labels[x_idx], labels[y_idx], labels[z_idx],
                            tmpws, filename, subtitle,
                            xmin=_at(lowers, x_idx), xmax=_at(uppers, x_idx),
                            ymin=_at(lowers, y_idx), ymax=_at(uppers, y_idx),
                            zmin=_at(lowers, z_idx), zmax=_at(uppers, z_idx),
                            xstep=_at(steps, x_idx), ystep=_at(steps, y_idx),
                            zstep=_at(steps, z_idx), plot_cfg=plot_cfg)
    return True
Exemple #16
0
def _partition_samples_by_output(out_values, subsections, input_sample):
    """Group the rows of `input_sample` by the subsection of `out_values` they fall in.

    Returns:
        (labels, samples): a legend label and the selected rows of `input_sample`
        for each consecutive pair of subsection boundaries.

    Raises:
        ValueError: If `subsections` is an int that is not greater than 0.
    """
    if not isinstance(out_values, numpy.ndarray):
        out_values = numpy.array(out_values)
    out_max = numpy.max(out_values)
    out_min = numpy.min(out_values)
    if isinstance(subsections, int):
        if subsections <= 0:
            raise ValueError(
                'subsections MUST be a integer greater than 0, or list.')
        # linspace (unlike arange with a float step) guarantees exactly
        # `subsections` intervals whose first/last boundaries equal out_min and
        # out_max exactly; the equality tests below rely on that, otherwise the
        # maximum value could fall outside every subsection.
        subsections = numpy.linspace(out_min, out_max, subsections + 1)
    if isinstance(subsections, list) and len(subsections) == 1:  # e.g., [0]
        section_pt = subsections[0]
        if out_min < section_pt < out_max:
            subsections = [out_min, section_pt, out_max]
        else:
            subsections = [out_min, out_max]
    labels = list()
    samples = list()
    for i in range(1, len(subsections)):
        lower = subsections[i - 1]
        upper = subsections[i]
        # Integral boundaries are printed without decimals, others with two
        decimal1 = 0 if int(lower) == float(lower) else 2
        decimal2 = 0 if int(upper) == float(upper) else 2
        lower_str = '{0:.{1}f}'.format(lower, decimal1)
        upper_str = '{0:.{1}f}'.format(upper, decimal2)
        if out_max == upper and out_min == lower:
            labels.append('%s=<y<=%s' % (lower_str, upper_str))
            zone = numpy.where((lower <= out_values) & (out_values <= upper))
        elif out_max == upper:
            labels.append('y>=%s' % lower_str)
            zone = numpy.where(lower <= out_values)
        elif out_min == lower:
            labels.append('y<%s' % upper_str)
            zone = numpy.where(out_values < upper)
        else:
            labels.append('%s=<y<%s' % (lower_str, upper_str))
            zone = numpy.where((lower <= out_values) & (out_values < upper))
        samples.append(input_sample[zone, :][0])
    return labels, samples


def empirical_cdf(out_values,
                  subsections,
                  input_sample,
                  names,
                  levels,
                  outpath,
                  outname,
                  param_dict,
                  plot_cfg=None):
    """Visualize the empirical cumulative distribution function (CDF) of each
    input variable, conditioned on subsections of the output values (y).

    The rows of `input_sample` are grouped by the subsection their output value
    falls into. For every input variable (column), one subplot shows a
    cumulative, density-normalized histogram for each group.

    Args:
        out_values: Output values (y); presumably one per row of `input_sample`.
        subsections: A positive int (number of equal-width subsections between
            min(y) and max(y)), a one-element list holding a single split point,
            or a sequence of subsection boundaries.
        input_sample: numpy array of input samples; column `i` is the variable
            named `names[i]`.
        names: Names of the input variables, one subplot each.
        levels: `bins` argument passed to `Axes.hist`.
        outpath: Output directory passed to save_png_eps.
        outname: Output file name without suffix.
        param_dict: Extra keyword arguments passed to `Axes.hist`.
        plot_cfg: (Optional) Plot settings for matplotlib.

    Raises:
        ValueError: If `subsections` is an int that is not greater than 0.
    """
    # prepare data
    labels, new_input_sample = _partition_samples_by_output(out_values, subsections,
                                                            input_sample)

    if plot_cfg is None:
        plot_cfg = PlotConfig()
    plt.rcParams['font.family'] = plot_cfg.font_name
    fig = plt.figure()

    num_vars = len(names)
    row, col = cal_row_col_num(num_vars)
    for var_idx in range(num_vars):
        ax = fig.add_subplot(row, col, var_idx + 1)
        # Draw the subsections from the last to the first
        for ii in range(len(labels) - 1, -1, -1):
            ax.hist(new_input_sample[ii][:, var_idx],
                    bins=levels,
                    density=True,
                    cumulative=True,
                    label=labels[ii],
                    **param_dict)
        ax.get_yaxis().set_major_locator(LinearLocator(numticks=5))
        ax.set_ylim(0, 1)
        ax.set_title('%s' % (names[var_idx]), fontsize=plot_cfg.title_fsize)
        ax.get_xaxis().set_major_locator(LinearLocator(numticks=3))
        ax.tick_params(
            axis='x',  # changes apply to the x-axis
            which='both',  # both major and minor ticks are affected
            bottom=True,  # ticks along the bottom edge are on
            top=False,  # ticks along the top edge are off
            labelbottom=True  # labels along the bottom edge are on
        )
        ax.tick_params(
            axis='y',  # changes apply to the y-axis
            which='major',  # only major ticks are affected
            length=3,
            right=False)  # no ticks along the right edge
        if var_idx % col:  # only the first column keeps its y tick labels
            ax.tick_params(axis='y', labelleft=False)
        if var_idx == 0:
            ax.legend(loc='lower right',
                      fontsize=plot_cfg.legend_fsize,
                      framealpha=0.8,
                      bbox_to_anchor=(1, 0),
                      borderaxespad=0.2,
                      fancybox=True)
    plt.tight_layout()
    save_png_eps(plt, outpath, outname, plot_cfg)
    # close current plot in case of 'figure.max_open_warning'
    plt.cla()
    plt.clf()
    plt.close()