示例#1
0
def plot_tern(probs,
              subsys,
              title,
              color_dict,
              ax=None,
              legend=False,
              info=True,
              markersize=10):
    """Scatter the rows of a three-column probability table on a ternary plot.

    Args:
        probs: DataFrame whose rows are plotted as ternary points (the
            ``.values`` of each group are passed to ``tax.scatter``). Its first
            three column names label the bottom, right and left axes.
        subsys: Series aligned with ``probs`` assigning each row to a group;
            rows whose group is missing (NaN) are drawn as 'Not Classified'.
        title: plot title.
        color_dict: mapping from every value returned by ``subsys.unique()``
            (including the NaN value, if present) to a matplotlib colour.
        ax: matplotlib Axes to draw on; a new 4.5 x 4 inch figure is created
            when None.
        legend: falsy to skip the legend, otherwise forwarded as the legend's
            ``loc`` argument.
        info: if True, draw tick marks and axis labels.
        markersize: scatter marker size (``s``).

    Returns:
        The matplotlib Figure that contains ``ax``.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(4.5, 4))
    else:
        # Ask the axes for its own figure: plt.gcf() returns whichever figure
        # is current, which need not be the one `ax` belongs to.
        fig = ax.get_figure()
    ax.set_aspect('equal')
    tax = ternary.TernaryAxesSubplot(ax=ax, scale=1.)

    font_size = pu.font_size()
    tax.set_title(f"{title}", fontsize=font_size)
    tax.boundary()
    tax.gridlines(multiple=.2, color="black")
    if info:
        tax.ticks(axis='lbr',
                  multiple=.2,
                  tick_formats="%.1f",
                  offset=0.05,
                  fontsize=font_size)

        tax.left_axis_label(f"{probs.columns[2]}",
                            offset=.25,
                            fontsize=font_size)
        tax.right_axis_label(f"{probs.columns[1]}",
                             offset=.25,
                             fontsize=font_size)
        tax.bottom_axis_label(f"{probs.columns[0]}",
                              offset=.28,
                              fontsize=font_size)

    for group in subsys.unique():
        if pd.isna(group):
            # Unclassified rows keep the default z-order, so the classified
            # groups (zorder=10) are drawn on top of them.
            tax.scatter(
                probs[subsys.isna()].values,
                color=color_dict[group],
                label='Not Classified',
                s=markersize,
            )
        else:
            tax.scatter(
                probs[subsys.values == group].values,
                color=color_dict[group],
                zorder=10,
                label=group,
                s=markersize,
            )
    if legend:
        tax.legend(loc=legend, fontsize=pu.label_size())

    tax.clear_matplotlib_ticks()
    tax.get_axes().axis('off')
    return fig
示例#2
0
    def draw(self,
             ax=None,
             setup=True,
             marker='o',
             color='k'):  # pragma: no cover
        """
        Plot the entropy triangle.

        Parameters
        ----------
        ax : Axis or None
            The matplotlib axis to draw on. When None, a fresh 10 x 8 inch
            ternary figure is created.
        setup : bool
            If true, decorate the plot with a boundary, gridlines, a title,
            axis labels and tick marks before plotting. Defaults to True.
        marker : str
            The matplotlib marker shape used for the points.
        color : str
            The marker color.

        Returns
        -------
        tax : ternary.TernaryAxesSubplot
            The ternary axes the points of ``self.points`` were drawn on.
        """
        import ternary

        if ax is not None:
            tax = ternary.TernaryAxesSubplot(ax=ax)
        else:
            new_fig, tax = ternary.figure()
            new_fig.set_size_inches(10, 8)

        if setup:
            tax.boundary()
            tax.gridlines(multiple=0.1)

            label_size = 20
            tax.set_title("Entropy Triangle", fontsize=label_size)
            axis_labellers = (
                (tax.left_axis_label, self.left_label),
                (tax.right_axis_label, self.right_label),
                (tax.bottom_axis_label, self.bottom_label),
            )
            for put_label, text in axis_labellers:
                put_label(text, fontsize=label_size)

            tax.ticks(axis='lbr', multiple=0.1, linewidth=1)
            tax.clear_matplotlib_ticks()

        tax.scatter(self.points, marker=marker, color=color)
        # Private python-ternary helper; presumably re-places the labels
        # after everything else has been drawn.
        tax._redraw_labels()

        return tax
def main(results, outfolder):
    """Summarise benchmark results into tables, box plots and ternary plots.

    ``results`` maps algorithm -> topology -> list of run records. Each record
    is read at ``[mode]['ping']['avg']``, ``[mode]['ping']['mdev']`` and
    ``[mode]['iperfs']`` for mode in ('sequential', 'parallel'), plus
    ``['routing_time']``.

    Metric codes (first letter / second letter):
        s = sequential, p = parallel, i = sub_itr(parallel, sequential);
        l = ping avg, j = ping mdev, b = avg(iperfs); ``rd`` = routing_time.

    Synthetic 'all' topology and 'all' algorithm entries pool every sample.
    For each metric, ``avg_<metric>`` and ``avgdiff_<metric>`` tables are
    written as .tex/.csv/.md. Per-algorithm (``am_*``) and per-topology
    (``tm_*``) box plots, plus ``tern_<p|s>_<topo>.pdf`` ternary plots, are
    saved into ``outfolder``.

    Args:
        results: nested mapping described above; must be non-empty. Every
            algorithm is assumed to cover the topologies of the first one.
        outfolder: output directory, as ``str`` or ``pathlib.Path``.

    Raises:
        ValueError: if ``results`` is empty.
    """
    # Normalise once: tables used Path(outfolder) but the plots called
    # outfolder.joinpath directly, which failed for plain strings.
    outfolder = Path(outfolder)
    if not results:
        raise ValueError("results must contain at least one algorithm")

    algos = sorted(results,
                   key=lambda a: (listfind(ALGOS_PREFERRED_ORDER, a), a))
    topos = sorted(next(iter(results.values())),
                   key=lambda a: (listfind(TOPOS_PREFERRED_ORDER, a), a))
    metrics = ['sl', 'sj', 'sb', 'pl', 'pj', 'pb', 'il', 'ij', 'ib', 'rd']

    # Per-run samples for every (algorithm, topology) pair.
    db = {}
    for algo in algos:
        db[algo] = {}
        for topo in topos:
            runs = results[algo][topo]
            entry = {}
            for mode, prefix in (('sequential', 's'), ('parallel', 'p')):
                entry[prefix + 'l'] = [r[mode]['ping']['avg'] for r in runs]
                entry[prefix + 'j'] = [r[mode]['ping']['mdev'] for r in runs]
                entry[prefix + 'b'] = [avg(r[mode]['iperfs']) for r in runs]
            entry['rd'] = [r['routing_time'] for r in runs]
            # i* metrics pair each parallel sample with its sequential one.
            for quantity in 'ljb':
                entry['i' + quantity] = list(
                    map(sub_itr,
                        zip(entry['p' + quantity], entry['s' + quantity])))
            db[algo][topo] = entry

    # Pool all topologies into an 'all' topology per algorithm ...
    for algo in algos:
        db[algo]['all'] = {
            metric: flatten([db[algo][topo][metric] for topo in topos])
            for metric in metrics
        }
    topos = [*topos, 'all']
    # ... then all algorithms into an 'all' algorithm.
    db['all'] = {
        topo: {
            metric: flatten([db[algo][topo][metric] for algo in algos])
            for metric in metrics
        }
        for topo in topos
    }
    algos = [*algos, 'all']
    real_algos = [algo for algo in algos if algo != 'all']
    real_topos = [topo for topo in topos if topo != 'all']

    avgdb = {
        algo: {
            topo: {persp: avg(db[algo][topo][persp]) for persp in metrics}
            for topo in topos
        }
        for algo in algos
    }
    avgdiffdb = {
        algo: {
            topo: {
                persp: avgdb[algo][topo][persp] - avgdb['all'][topo][persp]
                for persp in metrics
            }
            for topo in topos
        }
        for algo in real_algos
    }

    for metric in metrics:
        formatter = FMTS[metric]
        tables = [
            ('avg', build_table(avgdb, algos, topos, metric, formatter,
                                PHRASES)),
            ('avgdiff', build_table(avgdiffdb, real_algos, topos, metric,
                                    formatter, PHRASES)),
        ]
        for prefix, tbl in tables:
            stem = f'{prefix}_{metric}'
            outfolder.joinpath(f'{stem}.tex').write_text(tbl2tex(tbl))
            outfolder.joinpath(f'{stem}.csv').write_text(tbl2csv(tbl))
            outfolder.joinpath(f'{stem}.md').write_text(tbl2md(tbl))

    longer_phrases = Phrases(PHRASES)
    phrases = Phrases({**PHRASES, **PHRASES_BOXPLOT_PATCH})

    def scaled_desc(samples, metric):
        """Samples multiplied by the metric's display scale, largest first."""
        multiplier = FMTS_YSCALEMULTIPLIER[metric]
        return sorted((y * multiplier for y in samples), reverse=True)

    def save_boxplot(bp, name):
        bp.figure.savefig(outfolder.joinpath(name), bbox_inches='tight')
        # Close exactly the figure we saved rather than "the current" one.
        plt.close(bp.figure)
        print(name)

    for algo in algos:
        for metric in metrics:
            data = [scaled_desc(db[algo][topo][metric], metric)
                    for topo in real_topos]
            bp = make_boxplot(
                data, [phrases[x] for x in real_topos], FMTS_YLABEL[metric],
                f"{longer_phrases[metric]} using {longer_phrases[algo]}", True)
            save_boxplot(bp, f"am_{algo}_{metric}.pdf")
    for topo in topos:
        for metric in metrics:
            data = [scaled_desc(db[algo][topo][metric], metric)
                    for algo in real_algos]
            bp = make_boxplot(
                data, [phrases[x] for x in real_algos], FMTS_YLABEL[metric],
                f"{longer_phrases[metric]} on {longer_phrases[topo]}")
            save_boxplot(bp, f"tm_{topo}_{metric}.pdf")

    def minmax_scaled(samples, low, span):
        return [(v - low) / span for v in samples]

    for flavour in 'ps':
        for topo in topos:
            fig, ax = plt.subplots()
            ax.axis('off')
            tax = ternary.TernaryAxesSubplot(ax=ax, scale=1)
            tax.gridlines(multiple=0.1, color="gray")
            tax.boundary(linewidth=2.0)
            # Min/max over the pooled samples of all algorithms, so every
            # algorithm is rescaled the same way; a zero span becomes 1.
            pooled = db['all'][topo]
            ranges = {}
            for quantity in 'ljb':
                samples = pooled[flavour + quantity]
                low = min(samples)
                ranges[quantity] = (low, (max(samples) - low) or 1)
            for idx, algo in enumerate(real_algos):
                data = db[algo][topo]
                latency = minmax_scaled(data[flavour + 'l'], *ranges['l'])
                jitter = minmax_scaled(data[flavour + 'j'], *ranges['j'])
                bandwidth = minmax_scaled(data[flavour + 'b'], *ranges['b'])
                # python-ternary maps coordinate 0 -> right corner,
                # 1 -> top corner, 2 -> left corner. Order the tuple so the
                # data matches the corner/axis labels below (latency on top,
                # jitter right, bandwidth left).
                # NOTE(review): each coordinate is min-max scaled on its own,
                # so points do not sum to the scale (1).
                points = list(zip(jitter, latency, bandwidth))
                tax.scatter(points,
                            marker=MARKERS[idx],
                            color=COLORS[idx],
                            label=longer_phrases[algo])
            tax.top_corner_label("High latency")
            tax.right_corner_label("High jitter")
            tax.left_corner_label("High bandwidth")
            tax.bottom_axis_label("Low latency")
            tax.left_axis_label("Low jitter")
            tax.right_axis_label("Low bandwidth")
            tax.legend()
            name = f"tern_{flavour}_{topo}.pdf"
            tax.savefig(outfolder.joinpath(name))
            plt.close(fig)
            print(name)
示例#4
0
# NOTE: this excerpt continues an earlier example; `tax` (a ternary axes
# object) is created before this point and is not visible here.

# data can be plotted by entering data coords (rather than simplex coords):
points = [(70, 3, 27), (73, 2, 25), (68, 6, 26)]
# Convert the data-space tuples into the axes' coordinate order; 'brl'
# presumably lists the tuple entries as bottom, right, left axis values.
# NOTE(review): confirm against python-ternary's convert_coordinates docs.
points_c = tax.convert_coordinates(points, axisorder='brl')
tax.scatter(points_c, marker='o', s=25, c='r')

# Force an undistorted (equal-aspect) triangle, then call the private
# python-ternary helper that presumably re-places the labels afterwards.
tax.ax.set_aspect('equal', adjustable='box')
tax._redraw_labels()

## Zoom example:
## Draw a plot with the full range in the top panel and a second plot below
## it which shows a zoomed region of the first plot. add_subplot(2, 1, n)
## stacks the two axes vertically (2 rows, 1 column).
fig = ternary.plt.figure(figsize=(11, 6))
ax1 = fig.add_subplot(2, 1, 1)
ax2 = fig.add_subplot(2, 1, 2)

# Full-range triangle: scale 100, black gridlines every 10 units, no frame.
tax1 = ternary.TernaryAxesSubplot(ax=ax1, scale=100)
tax1.boundary(linewidth=1.0)
tax1.gridlines(color="black", multiple=10, linewidth=0.5, ls='-')
tax1.ax.axis("equal")
tax1.ax.axis("off")

# Zoomed triangle: scale 30, red boundary on all three axes ('b', 'r', 'l')
# and red gridlines every 5 units, no frame.
tax2 = ternary.TernaryAxesSubplot(ax=ax2, scale=30)
axes_colors = {'b': 'r', 'r': 'r', 'l': 'r'}
tax2.boundary(linewidth=1.0, axes_colors=axes_colors)
tax2.gridlines(color="r", multiple=5, linewidth=0.5, ls='-')
tax2.ax.axis("equal")
tax2.ax.axis("off")

fontsize = 16
tax1.set_title("Entire range")
tax1.left_axis_label("Logs", fontsize=fontsize, offset=0.12)
示例#5
0
def feature_ternary_heatmap(scale,
                            feature_name,
                            featurizer=None,
                            use_X=None,
                            style='triangular',
                            labelsize=11,
                            add_labeloffset=0,
                            cmap=None,
                            ax=None,
                            figsize=None,
                            vlim=None,
                            multiple=0.1,
                            tick_kwargs=None,
                            tern_axes=None,
                            tern_labels=None):
    """Draw a ternary heatmap of a single feature over the simplex.

    Args:
        scale: simplex scale passed to the simplex grid generator.
        feature_name: column of the feature matrix used for the colours.
        featurizer: featurizer instance used to compute features on the
            simplex grid (via ``featurize_simplex``); required when
            ``use_X`` is None.
        use_X: pre-computed feature DataFrame whose rows are aligned with
            ``simplex_iterator(scale)``; if given, ``featurizer`` is ignored.
        style: heatmap style passed to ``tax.heatmap``.
        labelsize: font size of the corner labels.
        add_labeloffset: extra offset added to every corner label.
        cmap: colormap for the heatmap.
        ax: matplotlib Axes to draw on; a new figure of ``figsize`` is
            created when None.
        figsize: figure size used only when ``ax`` is None.
        vlim: ``(vmin, vmax)`` colour limits; defaults to the feature range.
        multiple, tick_kwargs: currently unused (tick rescaling is disabled);
            kept for API compatibility.
        tern_axes: element order used to generate simplex compositions.
            Defaults to ['Ca', 'Al', 'Ba'].
        tern_labels: corner labels in (right, top, left) order. Defaults to
            ['CaO', 'Al$_2$O$_3$', 'BaO'].

    Returns:
        The ternary axes object holding the heatmap.

    Raises:
        ValueError: if neither ``use_X`` nor ``featurizer`` is supplied.
    """
    # None sentinels instead of shared mutable defaults.
    if tern_axes is None:
        tern_axes = ['Ca', 'Al', 'Ba']
    if tern_labels is None:
        tern_labels = ['CaO', 'Al$_2$O$_3$', 'BaO']

    if use_X is None:
        if featurizer is None:
            raise ValueError("either featurizer or use_X must be provided")
        feature_labels = featurizer.feature_labels()
        coords, X = featurize_simplex(scale,
                                      featurizer,
                                      feature_cols=feature_labels,
                                      tern_axes=tern_axes)
        X = pd.DataFrame(X, columns=feature_labels)
    else:
        coords = list(simplex_iterator(scale))
        X = use_X

    y = X.loc[:, feature_name]

    if vlim is None:
        vmin = min(y)
        vmax = max(y)
    else:
        vmin, vmax = vlim

    # The heatmap is keyed by the first two simplex coordinates.
    points = dict(zip([c[0:2] for c in coords], y))

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    tax = ternary.TernaryAxesSubplot(scale=scale, ax=ax)

    tax.heatmap(points,
                style=style,
                colorbar=False,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax)
    tax.boundary()
    tax.ax.axis('off')

    tax.right_corner_label(tern_labels[0],
                           fontsize=labelsize,
                           va='center',
                           offset=0.08 + add_labeloffset)
    tax.top_corner_label(tern_labels[1],
                         fontsize=labelsize,
                         va='center',
                         offset=0.05 + add_labeloffset)
    tax.left_corner_label(tern_labels[2],
                          fontsize=labelsize,
                          va='center',
                          offset=0.08 + add_labeloffset)

    tax._redraw_labels()

    return tax
示例#6
0
def estimator_ternary_heatmap(scale,
                              estimator,
                              featurizer=None,
                              feature_cols=None,
                              scaler=None,
                              use_X=None,
                              style='triangular',
                              labelsize=11,
                              add_labeloffset=0,
                              cmap=None,
                              ax=None,
                              figsize=None,
                              vlim=None,
                              metric='median',
                              multiple=0.1,
                              tick_kwargs=None,
                              tern_axes=None,
                              tern_labels=None):
    """Generate a ternary heatmap of ML predictions.

    Args:
        scale: simplex scale.
        estimator: sklearn estimator instance.
        featurizer: featurizer instance.
        feature_cols: subset of feature names used in model_eval.
        scaler: sklearn scaler instance.
        use_X: pre-calculated feature matrix; if passed, featurizer,
            feature_cols, and scaler are ignored.
        style: heatmap interpolation style.
        labelsize: font size of the corner labels.
        add_labeloffset: extra offset added to every corner label.
        cmap: colormap for the heatmap.
        ax: matplotlib Axes to draw on; a new figure of ``figsize`` is
            created when None.
        figsize: figure size used only when ``ax`` is None.
        vlim: ``(vmin, vmax)`` colour limits; defaults to the prediction
            range.
        metric: if 'median', return point estimate. If 'iqr', return IQR of
            prediction.
        multiple, tick_kwargs: currently unused (tick rescaling is disabled);
            kept for API compatibility.
        tern_axes: ternary axes. Only used for generating simplex
            compositions; ignored if use_X is supplied. Defaults to
            ['Ca', 'Al', 'Ba'].
        tern_labels: corner labels in (right, top, left) order. Defaults to
            ['CaO', 'Al$_2$O$_3$', 'BaO'].

    Returns:
        The ternary axes object holding the heatmap.
    """
    # None sentinels instead of shared mutable defaults.
    if tern_axes is None:
        tern_axes = ['Ca', 'Al', 'Ba']
    if tern_labels is None:
        tern_labels = ['CaO', 'Al$_2$O$_3$', 'BaO']

    coords, y = predict_simplex(estimator, scale, featurizer, feature_cols,
                                scaler, use_X, tern_axes, metric)

    if vlim is None:
        vmin = min(y)
        vmax = max(y)
    else:
        vmin, vmax = vlim

    # The heatmap is keyed by the first two simplex coordinates.
    points = dict(zip([c[0:2] for c in coords], y))

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    tax = ternary.TernaryAxesSubplot(scale=scale, ax=ax)

    tax.heatmap(points,
                style=style,
                colorbar=False,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax)
    tax.boundary()
    tax.ax.axis('off')

    tax.right_corner_label(tern_labels[0],
                           fontsize=labelsize,
                           va='center',
                           offset=0.08 + add_labeloffset)
    tax.top_corner_label(tern_labels[1],
                         fontsize=labelsize,
                         va='center',
                         offset=0.05 + add_labeloffset)
    tax.left_corner_label(tern_labels[2],
                          fontsize=labelsize,
                          va='center',
                          offset=0.08 + add_labeloffset)

    tax._redraw_labels()

    return tax
示例#7
0
def plot_labeled_ternary(comps,
                         values,
                         ax=None,
                         label_points=True,
                         add_labeloffset=0,
                         corner_labelsize=12,
                         point_labelsize=11,
                         point_labeloffset=None,
                         cmap=None,
                         vlim=None,
                         tern_axes=None,
                         tern_labels=None,
                         **scatter_kw):
    """Scatter compositions on a unit ternary, coloured (and optionally labelled) by value.

    Args:
        comps: compositions, converted to ternary points with
            ``get_coords_from_comp(comp, tern_axes)``.
        values: one value per composition; used for colour and, when
            ``label_points`` is truthy, as rounded integer annotations
            (missing values are annotated 'NA').
        ax: matplotlib Axes to draw on; a new 9 x 8 inch figure is created
            when None.
        label_points: annotate each point with its value.
        add_labeloffset: extra offset added to every corner label.
        corner_labelsize: font size of the corner labels.
        point_labelsize: font size of the point annotations.
        point_labeloffset: offset added element-wise to each point for its
            annotation. Defaults to ``[0, 0.01, 0]``.
        cmap: colormap for the scatter.
        vlim: ``(vmin, vmax)`` colour limits; defaults to the NaN-ignoring
            range of ``values``.
        tern_axes: element order passed to ``get_coords_from_comp``.
            Defaults to ['Ca', 'Al', 'Ba'].
        tern_labels: corner labels in (right, top, left) order. Defaults to
            ['CaO', 'Al$_2$O$_3$', 'BaO'].
        **scatter_kw: forwarded to ``tax.scatter``.

    Returns:
        The ternary axes object.
    """
    import numpy as np

    if point_labeloffset is None:
        point_labeloffset = [0, 0.01, 0]
    if tern_axes is None:
        tern_axes = ['Ca', 'Al', 'Ba']
    if tern_labels is None:
        tern_labels = ['CaO', 'Al$_2$O$_3$', 'BaO']

    if ax is None:
        _, ax = plt.subplots(figsize=(9, 8))

    tax = ternary.TernaryAxesSubplot(scale=1, ax=ax)

    points = [get_coords_from_comp(c, tern_axes) for c in comps]

    if vlim is None:
        # NaN-aware: values may contain missing entries (annotated 'NA'
        # below), and plain min/max give order-dependent results on NaN.
        vmin = np.nanmin(values)
        vmax = np.nanmax(values)
    else:
        vmin, vmax = vlim
    tax.scatter(points,
                c=values,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
                **scatter_kw)

    tax.right_corner_label(tern_labels[0],
                           fontsize=corner_labelsize,
                           va='center',
                           offset=0.08 + add_labeloffset)
    tax.top_corner_label(tern_labels[1],
                         fontsize=corner_labelsize,
                         va='center',
                         offset=0.05 + add_labeloffset)
    tax.left_corner_label(tern_labels[2],
                          fontsize=corner_labelsize,
                          va='center',
                          offset=0.08 + add_labeloffset)

    tax.boundary(linewidth=1)
    ax.axis('off')

    if label_points:
        for p, val in zip(points, values):
            if pd.isnull(val):
                disp = 'NA'
            else:
                disp = '{}'.format(int(round(val, 0)))
            # np.add offsets element-wise even when p is a tuple/list, where
            # `p + offset` would concatenate instead.
            tax.annotate(disp,
                         np.add(p, point_labeloffset),
                         size=point_labelsize,
                         ha='center',
                         va='bottom')

    tax._redraw_labels()

    return tax
    def plot(self, selected_type_1, selected_type_2, selected_operation,
             min_color_scale, max_color_scale, is_percentage):
        """Compute a derived quantity and show it as a coloured ternary scatter.

        ``selected_type_1``, ``selected_type_2``, ``selected_operation`` and
        ``is_percentage`` are forwarded to ``self.calculate``. That call
        returns the data (read here via its ``x``, ``y`` and ``calculated``
        columns) and a title.

        Rows whose ``calculated`` value is NaN or +/-inf are dropped. The
        remaining values are clamped into ``[min_color_scale,
        max_color_scale]``. Each limit is anything ``float()`` accepts, or
        None to use the data min/max.

        A status dict (color / message / tooltip) is emitted on
        ``self.task_bar_message``. It is orange when inf values were found,
        otherwise green.
        """
        # get config (model-supplied config wins over the default one)
        config = self._model.config_data
        if config is None:
            config = self.default_config

        # perform calculation
        data, title = self.calculate(self._model.ternary_file_data,
                                     selected_type_1, selected_type_2,
                                     selected_operation, is_percentage)
        # remove the inf and nan rows; remember the inf rows for the warning
        calculated = data['calculated']
        inf_nan_indexes = data.index[calculated.isin(
            [np.nan, np.inf, -np.inf])].tolist()
        inf_indexes = data[calculated.isin([np.inf, -np.inf])]
        data = data.drop(inf_nan_indexes)

        # get default min and max color scale if values are not defined
        if min_color_scale is None:
            min_color_scale = min(data["calculated"].values)
        else:
            min_color_scale = float(min_color_scale)
        if max_color_scale is None:
            max_color_scale = max(data["calculated"].values)
        else:
            max_color_scale = float(max_color_scale)

        # Clamp out-of-range values to the limits (lower bound first, then
        # upper, as before). Series.mask builds a new Series. Writing through
        # `.values` fails under pandas Copy-on-Write, where it is read-only.
        clamped = data['calculated']
        clamped = clamped.mask(clamped < min_color_scale, min_color_scale)
        clamped = clamped.mask(clamped > max_color_scale, max_color_scale)
        data['calculated'] = clamped

        points = np.array([data["x"].values, data["y"].values]).transpose()
        # colors map
        cm = LinearSegmentedColormap.from_list('Capacities',
                                               config["colors"],
                                               N=1024)

        # normalize data
        norm = Normalize(vmin=min_color_scale, vmax=max_color_scale)

        # get color based on color map
        data["calculated_norm"] = data["calculated"].apply(lambda x: norm(x))
        data["calculated_color"] = data["calculated_norm"].apply(
            lambda x: cm(x))
        colors = data["calculated_color"].values

        # Creates a ternary set of axes to plot the diagram from python-ternary
        fig, ax = plt.subplots(**config["figure"])
        # fix aspect ratio.
        ax.set_aspect("equal")
        ax.set_title("Figure : {} | {}".format(fig.number, title),
                     **config["title"])
        tax = ternary.TernaryAxesSubplot(ax=ax, scale=1.0)
        tax.boundary()
        tax.gridlines(**config["gridlines"])
        tax.scatter(points,
                    c=colors,
                    colormap=cm,
                    vmin=min_color_scale,
                    vmax=max_color_scale,
                    **config["scatter"])
        # colorbar: assumes the scatter config adds a colorbar axes last
        # NOTE(review): confirm config["scatter"] enables the colorbar.
        cbar_ax = fig.axes[-1]
        cbar_ax.tick_params(**config["colorbar_tick_params"])
        # ticks
        tax.ticks(**config["axis_ticks"])
        # set axis labels
        tax.left_axis_label("1 - x - y", **config["axis_label"])
        tax.right_axis_label("y", **config["axis_label"])
        tax.bottom_axis_label("x", **config["axis_label"])
        plt.axis('off')
        # if there are some inf index , so report it to user
        if not inf_indexes.empty:
            tooltip = inf_indexes[['x', 'y']].to_string()
            task_bar_data = {
                "color":
                "orange",
                'message':
                "Figure : {} | Warning: calculation contains inf | {}".format(
                    fig.number, title),
                'tooltip':
                tooltip
            }
            self.task_bar_message.emit(task_bar_data)
        else:
            self.task_bar_message.emit({
                "color":
                "green",
                "message":
                "Figure : {} | {}".format(fig.number, title),
                "tooltip":
                "{} rows were removed: {}".format(
                    str(len(inf_nan_indexes)),
                    ",".join(str(x) for x in inf_nan_indexes))
            })
        tax.show()
        tax.close()
示例#9
0
def quat_slice_heatmap2(tuple_scale,
                        zfunc,
                        slice_val,
                        zfunc_kwargs=None,
                        style='triangular',
                        slice_axis='Y',
                        tern_axes=None,
                        labelsize=14,
                        add_labeloffset=0,
                        cmap=plt.cm.viridis,
                        ax=None,
                        figsize=None,
                        vmin=None,
                        vmax=None,
                        Ba=1,
                        multiple=0.1,
                        tick_kwargs=None):
    """Heatmap of ``zfunc`` over one ternary slice of a quaternary system.

    The z values are computed from a formula rather than from the raw tuple.
    For each point of the ``tuple_scale`` simplex grid, a formula is built
    with ``sliceformula_from_tuple`` (``slice_axis`` fixed at ``slice_val``).
    It is coloured by ``zfunc(formula, **zfunc_kwargs)``.

    Args:
        tuple_scale: simplex grid scale.
        zfunc: callable mapping a formula to a scalar.
        slice_val: value of ``slice_axis`` for this slice. Tick labels are
            rescaled to ``1 - slice_val`` via ``rescale_ticks``.
        zfunc_kwargs: extra keyword arguments for ``zfunc``.
        style: heatmap style passed to ``tax.heatmap``.
        slice_axis: axis held fixed in the slice.
        tern_axes: the three remaining axes (right, top, left corner
            labels). Defaults to ['Co', 'Fe', 'Zr'].
        labelsize: font size of the corner labels.
        add_labeloffset: extra offset added to every corner label.
        cmap: colormap for the heatmap.
        ax: matplotlib Axes to draw on; a new figure of ``figsize`` is
            created when None.
        figsize: figure size used only when ``ax`` is None.
        vmin, vmax: colour limits; default to the min/max of the z values.
        Ba: passed through to ``sliceformula_from_tuple``.
        multiple: tick spacing passed to ``rescale_ticks``.
        tick_kwargs: extra keyword arguments for ``rescale_ticks``. Defaults
            to ``{'tick_formats': '%.1f', 'offset': 0.02}``.

    Returns:
        ``(tax, vmin, vmax)``: the ternary axes and the colour limits actually
        used, so callers can reuse them.
    """
    # None sentinels instead of shared mutable defaults.
    if zfunc_kwargs is None:
        zfunc_kwargs = {}
    if tern_axes is None:
        tern_axes = ['Co', 'Fe', 'Zr']
    if tick_kwargs is None:
        tick_kwargs = {'tick_formats': '%.1f', 'offset': 0.02}

    axis_scale = 1 - slice_val
    tuples = []
    zvals = []
    for tup in simplex_iterator(scale=tuple_scale):
        tuples.append(tup)
        formula = sliceformula_from_tuple(tup,
                                          slice_val=slice_val,
                                          slice_axis=slice_axis,
                                          tern_axes=tern_axes,
                                          Ba=Ba)
        zvals.append(zfunc(formula, **zfunc_kwargs))
    if vmin is None:
        vmin = min(zvals)
    if vmax is None:
        vmax = max(zvals)

    # The heatmap is keyed by the first two simplex coordinates.
    d = dict(zip([t[0:2] for t in tuples], zvals))

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    tax = ternary.TernaryAxesSubplot(scale=tuple_scale, ax=ax)

    tax.heatmap(d,
                style=style,
                colorbar=False,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax)
    rescale_ticks(tax, new_scale=axis_scale, multiple=multiple, **tick_kwargs)
    tax.boundary()
    tax.ax.axis('off')

    tax.right_corner_label(tern_axes[0],
                           fontsize=labelsize,
                           offset=0.08 + add_labeloffset)
    tax.top_corner_label(tern_axes[1],
                         fontsize=labelsize,
                         offset=0.2 + add_labeloffset)
    tax.left_corner_label(tern_axes[2],
                          fontsize=labelsize,
                          offset=0.08 + add_labeloffset)

    tax._redraw_labels()

    return tax, vmin, vmax
示例#10
0
def quat_slice_scatter(data,
                       z,
                       slice_start,
                       slice_width=0,
                       slice_axis='Y',
                       tern_axes=None,
                       tern_labels=None,
                       labelsize=14,
                       tick_kwargs=None,
                       nticks=5,
                       add_labeloffset=0,
                       cmap=plt.cm.viridis,
                       ax=None,
                       vmin=None,
                       vmax=None,
                       ptsize=8,
                       figsize=None,
                       scatter_kw=None):
    """Scatter the rows of ``data`` in one slice of a quaternary on a ternary plot.

    Rows are selected where ``data[slice_axis] == slice_start`` (when
    ``slice_width`` is 0) or ``slice_start <= data[slice_axis] <
    slice_start + slice_width``. Their ``tern_axes`` columns are rescaled to
    sum to ``1 - (slice_start + slice_width / 2)``, which becomes the plot
    scale. Points are coloured by column ``z``.

    Args:
        data: DataFrame containing ``slice_axis``, ``tern_axes`` and ``z``.
        z: column used for colour.
        slice_start, slice_width: slice selection (see above).
        slice_axis: column held (approximately) fixed.
        tern_axes: the three plotted columns. Defaults to ['Co', 'Fe', 'Zr'].
        tern_labels: corner labels (right, top, left); default ``tern_axes``.
        labelsize: font size of the corner labels.
        tick_kwargs: extra keyword arguments for ``tax.ticks``. Defaults to
            ``{'axis': 'lbr', 'linewidth': 1, 'tick_formats': '%.1f',
            'offset': 0.03}``.
        nticks: number of tick intervals per axis.
        add_labeloffset: extra offset added to every corner label.
        cmap: colormap for the scatter.
        ax: matplotlib Axes to draw on; a new figure of ``figsize`` is
            created when None.
        vmin, vmax: colour limits; default to the NaN-ignoring range of
            ``z`` within the slice (left as None if the slice is empty).
        ptsize: scatter marker size.
        figsize: figure size used only when ``ax`` is None.
        scatter_kw: extra keyword arguments for ``tax.scatter``.

    Returns:
        ``(tax, vmin, vmax)``.
    """
    # None sentinels instead of shared mutable defaults.
    if tern_axes is None:
        tern_axes = ['Co', 'Fe', 'Zr']
    if tick_kwargs is None:
        tick_kwargs = {
            'axis': 'lbr',
            'linewidth': 1,
            'tick_formats': '%.1f',
            'offset': 0.03
        }
    if scatter_kw is None:
        scatter_kw = {}

    if slice_width == 0:
        df = data[data[slice_axis] == slice_start]
    else:
        df = data[(data[slice_axis] >= slice_start)
                  & (data[slice_axis] < slice_start + slice_width)]

    points = df.loc[:, tern_axes].values
    colors = df.loc[:, z].values
    if df[z].isnull().any():
        print('Warning: null values in z column')
    # Only derive limits when the slice has data: np.min on an empty array
    # raises, which used to crash before the empty-slice guard further down.
    # NaN-aware so the null values warned about above do not poison them.
    if len(colors) > 0:
        if vmin is None:
            vmin = np.nanmin(colors)
        if vmax is None:
            vmax = np.nanmax(colors)

    scale = 1 - (slice_start + slice_width / 2)  # point coords must sum to scale
    print('Scale: {}'.format(scale))
    # Each composition has a different slice_axis value, so rescale the
    # points to the plot scale.
    ptsum = np.sum(points, axis=1)[np.newaxis].T
    scaled_pts = points * scale / ptsum

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    tax = ternary.TernaryAxesSubplot(scale=scale, ax=ax)

    if len(points) > 0:
        tax.scatter(scaled_pts,
                    s=ptsize,
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    colorbar=False,
                    c=colors,
                    **scatter_kw)

    tax.boundary(linewidth=1.0)
    tax.clear_matplotlib_ticks()

    multiple = scale / nticks
    # Set ticks and locations manually: the default tax.ticks behaviour
    # does not work with a fractional scale.
    ticks = list(np.arange(0, scale + 1e-6, multiple))
    locations = ticks

    tax.ticks(multiple=multiple,
              ticks=ticks,
              locations=locations,
              **tick_kwargs)
    tax.gridlines(multiple=multiple, linewidth=0.8)

    if tern_labels is None:
        tern_labels = tern_axes
    tax.right_corner_label(tern_labels[0],
                           fontsize=labelsize,
                           offset=0.08 + add_labeloffset)
    tax.top_corner_label(tern_labels[1],
                         fontsize=labelsize,
                         offset=0.2 + add_labeloffset)
    tax.left_corner_label(tern_labels[2],
                          fontsize=labelsize,
                          offset=0.08 + add_labeloffset)

    tax._redraw_labels()
    ax.axis('off')

    return tax, vmin, vmax
示例#11
0
def run_causal_simulations(t=0.8,
                           b=0.0,
                           outfile='figures/causal_learning_simulations.png',
                           dpi=600,
                           legend=False):
    """Compare three active-learning models on every causal-learning problem.

    For each problem from ``utils.create_active_learning_hyp_space``, the
    predictions of the information-gain, positive-test-strategy and
    self-teaching models are drawn as points in a small ternary subplot.
    The subplots form a 3 x 9 grid, so at most 27 problems fit. The figure
    is then saved.

    Args:
        t: transmission rate passed to the hypothesis-space builder.
        b: background rate passed to the hypothesis-space builder.
        outfile: path the figure is written to.
        dpi: resolution used when saving.
        legend: if True, add one figure-level legend for the three models.
            The scatter labels were always set but never displayed.
    """
    problems = utils.create_active_learning_hyp_space(t=t, b=b)
    ig_model_predictions = []
    self_teaching_model_predictions = []
    pts_model_predictions = []

    # get predictions of all three models
    for problem in problems:
        gal = GraphActiveLearner(problem)
        gal.update_posterior()
        ig_model_predictions.append(gal.expected_information_gain().tolist())

        gst = GraphSelfTeacher(problem)
        gst.update_learner_posterior()
        self_teaching_model_predictions.append(
            gst.update_self_teaching_posterior())

        gpts = GraphPositiveTestStrategy(problem)
        pts_model_predictions.append(gpts.positive_test_strategy())

    figure, ax = plt.subplots()
    figure.set_size_inches(16, 5)

    # Hide the background axes; the ternary subplots are added on top.
    ax.set_frame_on(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # Dotted guide lines from (approximately) the centroid to each edge midpoint.
    guide_lines = [
        ((0.33, 0.33, 0.33), (0.5, 0.5, 0)),
        ((0.33, 0.33, 0.33), (0.5, 0, 0.5)),
        ((0.33, 0.33, 0.33), (0, 0.5, 0.5)),
    ]

    handles, labels = [], []
    predictions = zip(ig_model_predictions, pts_model_predictions,
                      self_teaching_model_predictions)
    for i, (ig_pred, pts_pred, st_pred) in enumerate(predictions):
        # make ternary plot
        ax = figure.add_subplot(3, 9, i + 1)
        tax = ternary.TernaryAxesSubplot(ax=ax)
        tax.set_title("Problem {}".format(i + 1), fontsize=10)
        tax.boundary(linewidth=2.0)
        tax.scatter([ig_pred],
                    marker='o',
                    color='red',
                    label="Information Gain",
                    alpha=0.8,
                    s=40)
        tax.scatter([pts_pred],
                    marker='d',
                    color='green',
                    label="Positive-Test Strategy",
                    alpha=0.8,
                    s=40)
        tax.scatter([st_pred],
                    marker='s',
                    color='blue',
                    label="Self-Teaching",
                    alpha=0.8,
                    s=40)

        for start, end in guide_lines:
            tax.line(start, end, color='black', linestyle=':')
        tax.left_axis_label("x_1", fontsize=20)
        tax.right_axis_label("x_2", fontsize=20)
        tax.bottom_axis_label("x_3", fontsize=20)

        tax.clear_matplotlib_ticks()
        ax.set_frame_on(False)
        # Every subplot carries the same three labels; keep the last set.
        handles, labels = ax.get_legend_handles_labels()

    if legend and handles:
        figure.legend(handles, labels, loc='lower center', ncol=len(labels))

    figure.savefig(outfile, dpi=dpi)
示例#12
0
def plot(d,
         t,
         r,
         ax=None,
         title=False,
         scale=20,
         fontsize=20,
         blw=1.0,
         msz=8,
         small=False):
    """Draw a Sharma-Mittal entropy heatmap on the ternary simplex and overlay
    the prior and the posteriors of design ``d``.

    Parameters
    ----------
    d : object
        Design object. Reads ``d.prior`` (scaled by ``scale`` and drawn as a
        single square marker) and iterates ``d.posterior``. Each item ``q``
        supplies ``q[0]`` and ``q[1]``, the endpoints of one drawn segment.
        At most two posteriors are supported (colours/labels for Q_1, Q_2).
        NOTE(review): presumably numpy arrays of probabilities — confirm.
    t : float
        Degree passed to ``sharma_mittal.sm_entropy``.
    r : float
        Order passed to ``sharma_mittal.sm_entropy``.
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure and axes are created when None.
    title : str or False
        Text drawn centred above the axes. A falsy value disables it.
    scale : int
        Ternary subdivision count. Probabilities are multiplied by it.
    fontsize : int
        Font size of the three axis labels.
    blw : float
        Line width of the triangle boundary.
    msz : float
        Marker size for the prior/posterior markers.
    small : bool
        Compact mode: omit the order/degree annotations and place an
        expanded 3-column legend below the plot.

    Returns
    -------
    matplotlib Axes
        The axes that was drawn on.
    """
    def entropy(p):
        # Heatmap value at simplex point p: Sharma-Mittal entropy with
        # order r and degree t.
        return sharma_mittal.sm_entropy(p, t, r)

    if ax is None:
        _, ax = plt.subplots()

    tax = ternary.TernaryAxesSubplot(ax=ax, scale=scale)
    tax.gridlines(color="white", multiple=1, linestyle="-", linewidth=0.1)

    # num_ticks + 1 evenly spaced ticks labelled 0.0 .. 1.0. linspace keeps the
    # number of tick locations equal to the number of labels for any scale.
    # An arange with integer step int(scale / num_ticks) divides by zero for
    # scale < num_ticks and yields a mismatched count when scale is not a
    # multiple of num_ticks.
    num_ticks = 10
    tick_locs = np.linspace(0, scale, num_ticks + 1)
    ticks = [round(tick, 2) for tick in np.linspace(0, 1.0, num_ticks + 1)]

    tax.ticks(ticks=ticks,
              axis='lbr',
              axes_colors={
                  'l': 'grey',
                  'r': 'grey',
                  'b': 'grey'
              },
              offset=0.02,
              linewidth=1,
              locations=list(tick_locs),
              clockwise=True)
    tax.clear_matplotlib_ticks()
    ax.axis('off')

    tax.left_axis_label(r"$P(k_2) \quad \rightarrow$",
                        fontsize=fontsize,
                        offset=0.12)
    tax.right_axis_label(r"$P(k_1) \quad \rightarrow$",
                         fontsize=fontsize,
                         offset=0.12)
    tax.bottom_axis_label(r"$\leftarrow \quad P(k_3)$",
                          fontsize=fontsize,
                          offset=0.10)
    # NOTE(review): private python-ternary API; forces the labels to render now.
    tax._redraw_labels()

    cmap = sns.cubehelix_palette(light=1,
                                 dark=0,
                                 start=2.5,
                                 rot=0,
                                 as_cmap=True,
                                 reverse=True)
    # colorbar=False: no colorbar is wanted. Creating one and deleting it
    # afterwards would leave the main axes shrunk by the space it took, and
    # plt.gcf() is not necessarily the figure that owns `ax`.
    tax.heatmapf(entropy,
                 boundary=True,
                 style="triangular",
                 cmap=cmap,
                 colorbar=False)
    tax.boundary(linewidth=blw)

    # One face colour / legend name per posterior (only two are supported).
    q_color = ['lime', 'fuchsia']
    q_names = ['Q_1', 'Q_2']
    for i, q in enumerate(d.posterior):
        tax.line(q[0] * scale,
                 q[1] * scale,
                 linewidth=1.,
                 marker='o',
                 color='k',
                 linestyle="-",
                 markersize=msz,
                 markeredgecolor='k',
                 markerfacecolor=q_color[i],
                 markeredgewidth=1.,
                 label=r'$P(K|%s)$' % q_names[i])

    # Zero-length segment (same start and end point): only the prior's marker
    # is visible.
    tax.line(d.prior * scale,
             d.prior * scale,
             linewidth=1.,
             marker='s',
             color='k',
             linestyle="-",
             markersize=msz,
             markeredgecolor='k',
             markerfacecolor='yellow',
             markeredgewidth=1.,
             label=r'$P(K)$')

    if title:
        ax.text(0.5,
                1.05,
                title,
                ha='center',
                va='center',
                transform=ax.transAxes)

    if small:
        # Compact legend stretched across the area just below the axes.
        tax.legend(bbox_to_anchor=(0., 0.00, 1., -.102),
                   ncol=3,
                   mode="expand",
                   borderaxespad=0.,
                   handletextpad=0,
                   fancybox=True,
                   shadow=False,
                   frameon=True)
    else:
        # Annotate the entropy parameters in the top-left corner.
        ax.text(0.00,
                0.95,
                r'Order $\;\;(r) = %.2f$' % r,
                ha='left',
                va='center',
                transform=ax.transAxes)
        ax.text(0.00,
                0.9,
                r'Degree $(t) = %.2f$' % t,
                ha='left',
                va='center',
                transform=ax.transAxes)
        tax.legend()

    plt.draw()
    return ax


# d = design.design(size=(3,2))
# plot(d,t=1,r=1)
# plt.show()
示例#13
0
# Columns 31..45 hold five consecutive 3-column groups, one group per
# trajectory; unpack them into y0..y4 (stiff) and x0..x4 (non-stiff).
y0, y1, y2, y3, y4 = (stiff_data[:, col:col + 3] for col in range(31, 46, 3))
x0, x1, x2, x3, x4 = (nonstiff_data[:, col:col + 3]
                      for col in range(31, 46, 3))

mark_val = len(x4[:, 0])
figure, ax = plt.subplots(1, 2, figsize=(12, 6))
tax1 = ternary.TernaryAxesSubplot(ax=ax[0])

# Last row of y4 (presumably the equilibrium point — verify), followed by the
# first row of every stiff trajectory, all mapped through get_coordinates.
eq_point = get_coordinates(y4[-1, :])
st_point = [get_coordinates(traj[0, :]) for traj in (y0, y1, y2, y3, y4)]

# Shared label styling for the ternary axes.
fontsize = 12
offset = 0.105

tax1.boundary()
tax1.gridlines(multiple=0.1, color="black")
tax1.set_title("(a) Stiff Trajectories", fontsize=14)
tax1.left_axis_label("$H^+$", fontsize=fontsize, offset=offset)
tax1.right_axis_label("$H^*$", fontsize=fontsize, offset=offset)
示例#14
0
def format_tern_ax(
    ax=None,
    labels=None,
    labelsides=False,
    title=None,
    ticks=False,
    grid=True,
    removespines=True,
    scale=100,
    fontsize="x-large",
    ticks_kws=None,
    bound_kws=None,
    grid_kws=None,
):
    """
    Wrapper to return a ternary ax with optional grid, title, labels.

    *ax* is either None (creates a new fig and axes via ``ternary.figure``)
    or a matplotlib axes object to wrap.

    *labels* is a three-element sequence of strings. By default the labels
    go on the ternary plot tips in the order left, top, right. If
    *labelsides=True* they go on the sides instead, in the order left,
    bottom, right. A triangle has no top side, so the middle label is
    placed on the bottom axis. *scale* is an integer of subdivisions for
    ternary axes; *fontsize* controls the size of label and title text.

    *ticks_kws* is a dictionary of parameters for tax.ticks(),
    *bound_kws* for tax.boundary(), *grid_kws* for tax.gridlines().
    A dictionary you pass replaces the corresponding default entirely.
    When left as None the defaults are:
        ticks_kws = {"axis": "lbr", "linewidth": 1, "multiple": 10}
        bound_kws = {"linewidth": 1, "zorder": 0.6}
        grid_kws  = {"multiple": 10, "color": "grey", "linewidth": 1,
                     "zorder": 0.5}

    *labels*, *title*, *ticks*, *grid* can be set to False/None to turn
    off each respective item. *removespines* removes the regular
    matplotlib spines from the axes.

    Returns the ternary TernaryAxesSubplot.
    """
    # Fresh dicts per call (no shared mutable defaults).
    if ticks_kws is None:
        ticks_kws = {"axis": "lbr", "linewidth": 1, "multiple": 10}
    if bound_kws is None:
        bound_kws = {"linewidth": 1, "zorder": 0.6}
    if grid_kws is None:
        grid_kws = {"multiple": 10, "color": "grey", "linewidth": 1,
                    "zorder": 0.5}

    # Initiate ternary plot
    if ax is not None:
        tax = ternary.TernaryAxesSubplot(ax=ax, scale=scale)
    else:
        _, tax = ternary.figure(scale=scale)

    # Draw Boundary and Gridlines
    tax.boundary(**bound_kws)
    if grid:
        tax.gridlines(**grid_kws)

    # Set Axis labels and Title
    if title:
        tax.set_title(title, fontsize=fontsize)
    if labels:
        if labelsides:
            # TernaryAxesSubplot provides left/right/bottom axis labels only
            # (there is no top_axis_label), so the middle label goes bottom.
            tax.left_axis_label(labels[0], fontsize=fontsize)
            tax.bottom_axis_label(labels[1], fontsize=fontsize)
            tax.right_axis_label(labels[2], fontsize=fontsize)
        else:
            tax.left_corner_label(labels[0], fontsize=fontsize)
            # offset makes it look more symmetric
            tax.top_corner_label(labels[1], fontsize=fontsize, offset=0.18)
            tax.right_corner_label(labels[2], fontsize=fontsize)

    # Set ticks (previously called an undefined name `ternary_ax`).
    if ticks:
        tax.ticks(**ticks_kws)

    # Remove default Matplotlib Axes
    if removespines:
        tax.get_axes().axis("off")

    return tax