def plot_tern(probs, subsys, title, color_dict, ax=None, legend=False,
              info=True, markersize=10):
    """Scatter three-component rows on a ternary diagram, coloured by group.

    Parameters
    ----------
    probs : DataFrame-like
        Rows are plotted as points. ``probs.columns[0]``, ``[1]`` and ``[2]``
        are used as the bottom, right and left axis labels respectively.
    subsys : Series-like
        Group label per row of ``probs``; must support ``unique()``,
        ``isna()`` and ``values``. Missing labels are drawn under the legend
        entry ``'Not Classified'``.
    title : object
        Title of the plot (formatted with an f-string).
    color_dict : mapping
        Colour for every distinct value of ``subsys``.
        NOTE(review): the missing-value group is looked up with the NaN
        object itself as key -- confirm ``color_dict`` holds that key.
    ax : matplotlib axis or None
        Axis to draw on. A new 4.5 x 4 inch figure is created when omitted.
    legend : str or False
        Passed as ``loc`` to the legend; no legend is drawn when falsy.
    info : bool
        If true, tick marks and axis labels are drawn.
    markersize : number
        Marker size forwarded to ``scatter`` as ``s``.

    Returns
    -------
    The figure that owns the axis that was drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(4.5, 4))
    else:
        # Use the figure that actually owns ``ax``; the pyplot "current
        # figure" can be a different one when several figures are open.
        fig = ax.get_figure()
    ax.set_aspect('equal')

    tax = ternary.TernaryAxesSubplot(ax=ax, scale=1.)
    tax.set_title(f"{title}", fontsize=pu.font_size())
    tax.boundary()
    tax.gridlines(multiple=.2, color="black")

    if info:
        tax.ticks(axis='lbr', multiple=.2, tick_formats="%.1f", offset=0.05,
                  fontsize=pu.font_size())
        tax.left_axis_label(f"{probs.columns[2]}", offset=.25,
                            fontsize=pu.font_size())
        tax.right_axis_label(f"{probs.columns[1]}", offset=.25,
                             fontsize=pu.font_size())
        tax.bottom_axis_label(f"{probs.columns[0]}", offset=.28,
                              fontsize=pu.font_size())

    for group in subsys.unique():
        if pd.isna(group):
            # NaN never compares equal to itself, so select these rows with
            # isna() instead of an equality mask.
            tax.scatter(
                probs[subsys.isna()].values,
                color=color_dict[group],
                label='Not Classified',
                s=markersize,
            )
        else:
            # zorder=10 keeps classified points above the unclassified ones.
            tax.scatter(
                probs[subsys.values == group].values,
                color=color_dict[group],
                zorder=10,
                label=group,
                s=markersize,
            )

    if legend:
        tax.legend(loc=legend, fontsize=pu.label_size())

    tax.clear_matplotlib_ticks()
    tax.get_axes().axis('off')
    return fig
def draw(self, ax=None, setup=True, marker='o', color='k'):  # pragma: no cover
    """
    Plot the entropy triangle.

    Parameters
    ----------
    ax : Axis or None
        The matplotlib axis to plot on. If none is provided, a new 10 x 8
        inch ternary figure is constructed.
    setup : bool
        If true, a boundary, gridlines, the title, axis labels and tick
        marks are added to the plot. Defaults to True.
    marker : str
        The matplotlib marker shape to use.
    color : str
        The color of marker to use.

    Returns
    -------
    tax : ternary axes
        The ternary axes object the points were drawn on.
    """
    import ternary

    if ax is not None:
        tax = ternary.TernaryAxesSubplot(ax=ax)
    else:
        fig, tax = ternary.figure()
        fig.set_size_inches(10, 8)

    if setup:
        label_size = 20

        tax.boundary()
        tax.gridlines(multiple=0.1)

        tax.set_title("Entropy Triangle", fontsize=label_size)
        tax.left_axis_label(self.left_label, fontsize=label_size)
        tax.right_axis_label(self.right_label, fontsize=label_size)
        tax.bottom_axis_label(self.bottom_label, fontsize=label_size)

        tax.ticks(axis='lbr', multiple=0.1, linewidth=1)
        tax.clear_matplotlib_ticks()

    tax.scatter(self.points, marker=marker, color=color)
    tax._redraw_labels()

    return tax
def main(results, outfolder):
    """Aggregate benchmark results and write tables, box plots and ternary plots.

    Parameters
    ----------
    results : mapping
        ``results[algo][topo]`` is a sequence of run records. Each record is
        read at ``['sequential'|'parallel']['ping']['avg'|'mdev']``,
        ``['sequential'|'parallel']['iperfs']`` and ``['routing_time']``.
        The topology names are taken from the first algorithm's entry.
    outfolder : str or path-like
        Directory that receives every output file.

    Metric keys are two letters: the first is ``s`` (sequential), ``p``
    (parallel) or ``i`` (``sub_itr`` applied to each (parallel, sequential)
    pair); the second is ``l`` (ping avg), ``j`` (ping mdev) or ``b``
    (averaged iperfs). ``rd`` is the routing time.

    Files written into ``outfolder``:

    * ``avg_<metric>`` and ``avgdiff_<metric>`` tables as .tex, .csv and .md
    * ``am_<algo>_<metric>.pdf`` box plots (one box per topology)
    * ``tm_<topo>_<metric>.pdf`` box plots (one box per algorithm)
    * ``tern_<flavour>_<topo>.pdf`` ternary scatter plots
    """
    # Normalise once so that a plain string works for every output below;
    # previously only the table files tolerated a str argument.
    outfolder = Path(outfolder)

    algos = sorted(
        results.keys(),
        key=lambda a: (listfind(ALGOS_PREFERRED_ORDER, a), a))
    topos = sorted(
        next(iter(results.values())).keys(),
        key=lambda a: (listfind(TOPOS_PREFERRED_ORDER, a), a))
    metrics = ['sl', 'sj', 'sb', 'pl', 'pj', 'pb', 'il', 'ij', 'ib', 'rd']

    # Raw per-run samples for every (algorithm, topology) pair.
    db = {}
    for algo in algos:
        db[algo] = {}
        for topo in topos:
            runs = results[algo][topo]
            entry = {
                'sl': [res['sequential']['ping']['avg'] for res in runs],
                'sj': [res['sequential']['ping']['mdev'] for res in runs],
                'sb': [avg(res['sequential']['iperfs']) for res in runs],
                'pl': [res['parallel']['ping']['avg'] for res in runs],
                'pj': [res['parallel']['ping']['mdev'] for res in runs],
                'pb': [avg(res['parallel']['iperfs']) for res in runs],
                'rd': [res['routing_time'] for res in runs],
            }
            # Derived metrics pair each parallel sample with its sequential
            # counterpart.
            for kind in 'ljb':
                entry['i' + kind] = list(
                    map(sub_itr, zip(entry['p' + kind], entry['s' + kind])))
            db[algo][topo] = entry

    # Pseudo-topology 'all': every topology's samples pooled per algorithm.
    for algo in algos:
        db[algo]['all'] = {
            metric: flatten([db[algo][topo][metric] for topo in topos])
            for metric in metrics
        }
    topos = [*topos, 'all']

    # Pseudo-algorithm 'all': every algorithm's samples pooled per topology.
    db['all'] = {
        topo: {
            metric: flatten([db[algo][topo][metric] for algo in algos])
            for metric in metrics
        }
        for topo in topos
    }
    algos = [*algos, 'all']

    real_algos = [algo for algo in algos if algo != 'all']
    real_topos = [topo for topo in topos if topo != 'all']

    avgdb = {
        algo: {
            topo: {persp: avg(db[algo][topo][persp]) for persp in metrics}
            for topo in topos
        }
        for algo in algos
    }
    # Distance of each algorithm's average from the pooled average.
    avgdiffdb = {
        algo: {
            topo: {
                persp: avgdb[algo][topo][persp] - avgdb['all'][topo][persp]
                for persp in metrics
            }
            for topo in topos
        }
        for algo in real_algos
    }

    for metric in metrics:
        formatter = FMTS[metric]
        tables = [
            ('avg',
             build_table(avgdb, algos, topos, metric, formatter, PHRASES)),
            ('avgdiff',
             build_table(avgdiffdb, list(real_algos), topos, metric,
                         formatter, PHRASES)),
        ]
        for prefix, tbl in tables:
            outfolder.joinpath(f'{prefix}_{metric}.tex').write_text(
                tbl2tex(tbl))
            outfolder.joinpath(f'{prefix}_{metric}.csv').write_text(
                tbl2csv(tbl))
            outfolder.joinpath(f'{prefix}_{metric}.md').write_text(
                tbl2md(tbl))

    longer_phrases = Phrases(PHRASES)
    phrases = Phrases({**PHRASES, **PHRASES_BOXPLOT_PATCH})

    # One box plot per (algorithm, metric); one box per real topology.
    for algo in algos:
        for metric in metrics:
            ylabel = FMTS_YLABEL[metric]
            yscalemultiplier = FMTS_YSCALEMULTIPLIER[metric]
            data = [
                sorted([y * yscalemultiplier for y in db[algo][topo][metric]],
                       reverse=True)
                for topo in real_topos
            ]
            bp = make_boxplot(
                data, [phrases[x] for x in real_topos], ylabel,
                f"{longer_phrases[metric]} using {longer_phrases[algo]}",
                True)
            bp.figure.savefig(outfolder.joinpath(f"am_{algo}_{metric}.pdf"),
                              bbox_inches='tight')
            plt.cla()
            plt.clf()
            plt.close()
            print(f"am_{algo}_{metric}.pdf")

    # One box plot per (topology, metric); one box per real algorithm.
    for topo in topos:
        for metric in metrics:
            ylabel = FMTS_YLABEL[metric]
            yscalemultiplier = FMTS_YSCALEMULTIPLIER[metric]
            data = [
                sorted([y * yscalemultiplier for y in db[algo][topo][metric]],
                       reverse=True)
                for algo in real_algos
            ]
            bp = make_boxplot(
                data, [phrases[x] for x in real_algos], ylabel,
                f"{longer_phrases[metric]} on {longer_phrases[topo]}")
            bp.figure.savefig(outfolder.joinpath(f"tm_{topo}_{metric}.pdf"),
                              bbox_inches='tight')
            plt.cla()
            plt.clf()
            plt.close()
            print(f"tm_{topo}_{metric}.pdf")

    # Ternary plots: 'p' (parallel) and 's' (sequential) flavours.
    for flavour in 'ps':
        for topo in topos:
            _, ax = plt.subplots()
            ax.axis('off')
            tax = ternary.TernaryAxesSubplot(ax=ax, scale=1)
            tax.gridlines(multiple=0.1, color="gray")
            tax.boundary(linewidth=2.0)

            # Each metric is min-max scaled against the pooled samples of
            # this topology.
            allthistopo = db['all'][topo]
            minl = min(allthistopo[flavour + 'l'])
            minj = min(allthistopo[flavour + 'j'])
            minb = min(allthistopo[flavour + 'b'])
            maxl = max(allthistopo[flavour + 'l'])
            maxj = max(allthistopo[flavour + 'j'])
            maxb = max(allthistopo[flavour + 'b'])
            # A zero range would divide by zero; fall back to 1.
            dltl = (maxl - minl) or 1
            dltj = (maxj - minj) or 1
            dltb = (maxb - minb) or 1

            # The index runs over the full list (including 'all') so every
            # algorithm keeps the same marker and colour in every plot.
            for i, algo in enumerate(algos):
                if algo == 'all':
                    continue
                data = db[algo][topo]
                points = list(
                    zip(
                        [(v - minl) / dltl for v in data[flavour + 'l']],  # top
                        [(v - minj) / dltj for v in data[flavour + 'j']],  # right
                        [(v - minb) / dltb for v in data[flavour + 'b']],  # left
                    ))
                tax.scatter(points,
                            marker=MARKERS[i],
                            color=COLORS[i],
                            label=longer_phrases[algo])

            tax.top_corner_label("High latency")
            tax.right_corner_label("High jitter")
            tax.left_corner_label("High bandwidth")
            tax.bottom_axis_label("Low latency")
            tax.left_axis_label("Low jitter")
            tax.right_axis_label("Low bandwidth")
            tax.legend()
            tax.savefig(outfolder.joinpath(f"tern_{flavour}_{topo}.pdf"))
            plt.cla()
            plt.clf()
            plt.close()
            print(f"tern_{flavour}_{topo}.pdf")
# data can be plotted by entering data coords (rather than simplex coords): points = [(70, 3, 27), (73, 2, 25), (68, 6, 26)] points_c = tax.convert_coordinates(points, axisorder='brl') tax.scatter(points_c, marker='o', s=25, c='r') tax.ax.set_aspect('equal', adjustable='box') tax._redraw_labels() ## Zoom example: ## Draw a plot with the full range on the left and a second plot which ## shows a zoomed region of the left plot. fig = ternary.plt.figure(figsize=(11, 6)) ax1 = fig.add_subplot(2, 1, 1) ax2 = fig.add_subplot(2, 1, 2) tax1 = ternary.TernaryAxesSubplot(ax=ax1, scale=100) tax1.boundary(linewidth=1.0) tax1.gridlines(color="black", multiple=10, linewidth=0.5, ls='-') tax1.ax.axis("equal") tax1.ax.axis("off") tax2 = ternary.TernaryAxesSubplot(ax=ax2, scale=30) axes_colors = {'b': 'r', 'r': 'r', 'l': 'r'} tax2.boundary(linewidth=1.0, axes_colors=axes_colors) tax2.gridlines(color="r", multiple=5, linewidth=0.5, ls='-') tax2.ax.axis("equal") tax2.ax.axis("off") fontsize = 16 tax1.set_title("Entire range") tax1.left_axis_label("Logs", fontsize=fontsize, offset=0.12)
def feature_ternary_heatmap(scale,
                            feature_name,
                            featurizer=None,
                            use_X=None,
                            style='triangular',
                            labelsize=11,
                            add_labeloffset=0,
                            cmap=None,
                            ax=None,
                            figsize=None,
                            vlim=None,
                            multiple=0.1,
                            tick_kwargs={
                                'tick_formats': '%.1f',
                                'offset': 0.02
                            },
                            tern_axes=['Ca', 'Al', 'Ba'],
                            tern_labels=['CaO', 'Al$_2$O$_3$', 'BaO']):
    """
    Generate ternary heatmap of a single feature over the simplex

    Args:
        scale: simplex scale
        feature_name: column of the feature matrix to plot
        featurizer: featurizer instance; required when use_X is not given.
            Its feature_labels() are used as the feature matrix columns
        use_X: pre-calculated feature matrix (must support .loc column
            selection); if passed, featurizer is not used and rows are
            paired with the points of simplex_iterator(scale) in order
        style: heatmap interpolation style
        labelsize: font size of the corner labels
        add_labeloffset: extra offset added to every corner label
        cmap: colormap passed to the heatmap
        ax: matplotlib axis to draw on; a new figure is created if None
        figsize: size of the new figure; only used if ax is None
        vlim: (vmin, vmax) color limits. Defaults to the min/max of the
            plotted feature
        multiple, tick_kwargs: currently unused (tick rescaling is
            disabled in the body)
        tern_axes: ternary axes. Only used for generating simplex
            compositions; ignored if use_X is supplied
        tern_labels: corner labels in the order right, top, left

    Returns:
        the ternary axes object
    """
    if use_X is None:
        feature_labels = featurizer.feature_labels()
        coords, X = featurize_simplex(scale,
                                      featurizer,
                                      feature_cols=feature_labels,
                                      tern_axes=tern_axes)
        X = pd.DataFrame(X, columns=feature_labels)
    else:
        coords = list(simplex_iterator(scale))
        X = use_X

    y = X.loc[:, feature_name]

    if vlim is None:
        vmin = min(y)
        vmax = max(y)
    else:
        vmin, vmax = vlim

    # heatmap data is keyed on the first two simplex coordinates only
    points = dict(zip([c[0:2] for c in coords], y))

    # identity test: "== None" can be overridden (e.g. element-wise by arrays)
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
        _, tax = ternary.figure(scale=scale, ax=ax)
    else:
        tax = ternary.TernaryAxesSubplot(scale=scale, ax=ax)

    tax.heatmap(points,
                style=style,
                colorbar=False,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax)
    #rescale_ticks(tax,new_scale=axis_scale,multiple = multiple, **tick_kwargs)
    tax.boundary()
    tax.ax.axis('off')

    tax.right_corner_label(tern_labels[0],
                           fontsize=labelsize,
                           va='center',
                           offset=0.08 + add_labeloffset)
    tax.top_corner_label(tern_labels[1],
                         fontsize=labelsize,
                         va='center',
                         offset=0.05 + add_labeloffset)
    tax.left_corner_label(tern_labels[2],
                          fontsize=labelsize,
                          va='center',
                          offset=0.08 + add_labeloffset)

    tax._redraw_labels()

    return tax
def estimator_ternary_heatmap(scale,
                              estimator,
                              featurizer=None,
                              feature_cols=None,
                              scaler=None,
                              use_X=None,
                              style='triangular',
                              labelsize=11,
                              add_labeloffset=0,
                              cmap=None,
                              ax=None,
                              figsize=None,
                              vlim=None,
                              metric='median',
                              multiple=0.1,
                              tick_kwargs={
                                  'tick_formats': '%.1f',
                                  'offset': 0.02
                              },
                              tern_axes=['Ca', 'Al', 'Ba'],
                              tern_labels=['CaO', 'Al$_2$O$_3$', 'BaO']):
    """
    Generate ternary heatmap of ML predictions

    Args:
        scale: simplex scale
        estimator: sklearn estimator instance
        featurizer: featurizer instance
        feature_cols: subset of feature names used in model_eval
        scaler: sklearn scaler instance
        use_X: pre-calculated feature matrix; if passed, featurizer,
            feature_cols, and scaler are ignored
        style: heatmap interpolation style
        labelsize: font size of the corner labels
        add_labeloffset: extra offset added to every corner label
        cmap: colormap passed to the heatmap
        ax: matplotlib axis to draw on; a new figure is created if None
        figsize: size of the new figure; only used if ax is None
        vlim: (vmin, vmax) color limits. Defaults to the min/max of the
            predictions
        metric: if 'median', return point estimate. If 'iqr', return IQR
            of prediction
        multiple, tick_kwargs: currently unused (tick rescaling is
            disabled in the body)
        tern_axes: ternary axes. Only used for generating simplex
            compositions; ignored if use_X is supplied. Defaults to
            ['Ca','Al','Ba']
        tern_labels: corner labels in the order right, top, left

    Returns:
        the ternary axes object
    """
    coords, y = predict_simplex(estimator, scale, featurizer, feature_cols,
                                scaler, use_X, tern_axes, metric)

    if vlim is None:
        vmin = min(y)
        vmax = max(y)
    else:
        vmin, vmax = vlim

    # heatmap data is keyed on the first two simplex coordinates only
    points = dict(zip([c[0:2] for c in coords], y))

    # identity test: "== None" can be overridden (e.g. element-wise by arrays)
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
        _, tax = ternary.figure(scale=scale, ax=ax)
    else:
        tax = ternary.TernaryAxesSubplot(scale=scale, ax=ax)

    tax.heatmap(points,
                style=style,
                colorbar=False,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax)
    #rescale_ticks(tax,new_scale=axis_scale,multiple = multiple, **tick_kwargs)
    tax.boundary()
    tax.ax.axis('off')

    tax.right_corner_label(tern_labels[0],
                           fontsize=labelsize,
                           va='center',
                           offset=0.08 + add_labeloffset)
    tax.top_corner_label(tern_labels[1],
                         fontsize=labelsize,
                         va='center',
                         offset=0.05 + add_labeloffset)
    tax.left_corner_label(tern_labels[2],
                          fontsize=labelsize,
                          va='center',
                          offset=0.08 + add_labeloffset)

    tax._redraw_labels()

    return tax
def plot_labeled_ternary(comps,
                         values,
                         ax=None,
                         label_points=True,
                         add_labeloffset=0,
                         corner_labelsize=12,
                         point_labelsize=11,
                         point_labeloffset=[0, 0.01, 0],
                         cmap=None,
                         vlim=None,
                         **scatter_kw):
    """
    Scatter compositions on a CaO / Al2O3 / BaO ternary, colored by value

    Args:
        comps: compositions; each is converted to ternary coordinates with
            get_coords_from_comp using the axes ['Ca', 'Al', 'Ba']
        values: one value per composition, used for the point colors and
            the point labels
        ax: matplotlib axis to draw on; a new 9 x 8 inch figure is created
            if None
        label_points: if true, annotate every point with its value rounded
            to an integer ('NA' for null values)
        add_labeloffset: extra offset added to every corner label
        corner_labelsize: font size of the corner labels
        point_labelsize: font size of the point annotations
        point_labeloffset: offset added to each point's coordinates to
            position its annotation
        cmap: colormap passed to scatter
        vlim: (vmin, vmax) color limits. Defaults to the min/max of the
            non-null values
        scatter_kw: extra keyword arguments passed to scatter

    Returns:
        the ternary axes object
    """
    tern_axes = ['Ca', 'Al', 'Ba']

    if ax is None:
        _, ax = plt.subplots(figsize=(9, 8))

    #tfig, tax = ternary.figure(scale=1,ax=ax)
    tax = ternary.TernaryAxesSubplot(scale=1, ax=ax)

    points = [get_coords_from_comp(c, tern_axes) for c in comps]

    if vlim is None:
        # Null values are allowed (they are labelled 'NA' below). The builtin
        # min/max give order-dependent results once a NaN is present, so
        # use pandas reductions, which skip nulls.
        finite_values = pd.Series(values)
        vmin = finite_values.min()
        vmax = finite_values.max()
    else:
        vmin, vmax = vlim

    tax.scatter(points,
                c=values,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
                **scatter_kw)

    tern_labels = ['CaO', 'Al$_2$O$_3$', 'BaO']
    tax.right_corner_label(tern_labels[0],
                           fontsize=corner_labelsize,
                           va='center',
                           offset=0.08 + add_labeloffset)
    tax.top_corner_label(tern_labels[1],
                         fontsize=corner_labelsize,
                         va='center',
                         offset=0.05 + add_labeloffset)
    tax.left_corner_label(tern_labels[2],
                          fontsize=corner_labelsize,
                          va='center',
                          offset=0.08 + add_labeloffset)

    tax.boundary(linewidth=1)
    #tax.clear_matplotlib_ticks()
    ax.axis('off')

    if label_points:
        for p, val in zip(points, values):
            if pd.isnull(val):
                disp = 'NA'
            else:
                disp = '{}'.format(int(round(val, 0)))
            # NOTE(review): assumes p supports element-wise "+" with a list
            # (e.g. a numpy array) -- confirm against get_coords_from_comp
            tax.annotate(disp,
                         p + point_labeloffset,
                         size=point_labelsize,
                         ha='center',
                         va='bottom')

    #add_colorbar(fig,label='NH3 Production Rate (mmol/g$\cdot$h)',vmin=min(values),vmax=max(values),cbrect=[0.9,0.2,0.03,0.67])

    tax._redraw_labels()

    return tax
def plot(self, selected_type_1, selected_type_2, selected_operation,
         min_color_scale, max_color_scale, is_percentage):
    """Compute the selected quantity and show it as a coloured ternary scatter.

    The model's ternary data is passed to ``self.calculate``. Rows whose
    ``'calculated'`` value is NaN or infinite are dropped, the remaining
    values are clamped to the colour range, and one point per row is drawn
    at (``x``, ``y``). A status message is emitted on
    ``self.task_bar_message``: an orange warning when infinite values were
    found, otherwise a green message listing the removed rows.

    Parameters
    ----------
    selected_type_1, selected_type_2, selected_operation, is_percentage
        Forwarded unchanged to ``self.calculate``.
    min_color_scale, max_color_scale : number, str or None
        Limits of the colour scale; converted with ``float``. ``None``
        means the minimum / maximum of the finite calculated values.
    """
    # get config
    if self._model.config_data is not None:
        config = self._model.config_data
    else:
        config = self.default_config

    # perform calculation
    data, title = self.calculate(self._model.ternary_file_data,
                                 selected_type_1, selected_type_2,
                                 selected_operation, is_percentage)

    # remove the inf and nan (the inf rows are kept aside for the warning)
    inf_nan_indexes = data.index[data['calculated'].isin(
        [np.nan, np.inf, -np.inf])].tolist()
    inf_indexes = data[data['calculated'].isin([np.inf, -np.inf])]
    data = data.drop(inf_nan_indexes)

    # get default min and max color scale if values are not defined
    if min_color_scale is None:
        min_color_scale = min(data["calculated"].values)
    else:
        min_color_scale = float(min_color_scale)
    if max_color_scale is None:
        max_color_scale = max(data["calculated"].values)
    else:
        max_color_scale = float(max_color_scale)

    # Replace out-of-range values with the nearest limit. This is done with
    # Series.mask and a column assignment instead of writing through
    # ``.values``: the latter depends on ``.values`` being a writable view
    # and silently truncates the limits in integer columns.
    calculated = data['calculated']
    calculated = calculated.mask(calculated < min_color_scale,
                                 min_color_scale)
    calculated = calculated.mask(calculated > max_color_scale,
                                 max_color_scale)
    data['calculated'] = calculated

    points = np.array([data["x"].values, data["y"].values]).transpose()

    # colors map
    cm = LinearSegmentedColormap.from_list('Capacities',
                                           config["colors"],
                                           N=1024)
    # normalize data
    norm = Normalize(vmin=min_color_scale, vmax=max_color_scale)
    # get color based on color map
    data["calculated_norm"] = data["calculated"].apply(
        lambda value: norm(value))
    data["calculated_color"] = data["calculated_norm"].apply(
        lambda value: cm(value))
    colors = data["calculated_color"].values

    # Creates a ternary set of axes to plot the diagram from python-ternary
    fig, ax = plt.subplots(**config["figure"])
    # fix aspect ratio.
    ax.set_aspect("equal")
    ax.set_title("Figure : {} | {}".format(fig.number, title),
                 **config["title"])
    tax = ternary.TernaryAxesSubplot(ax=ax, scale=1.0)
    tax.boundary()
    tax.gridlines(**config["gridlines"])
    tax.scatter(points,
                c=colors,
                colormap=cm,
                vmin=min_color_scale,
                vmax=max_color_scale,
                **config["scatter"])

    # colorbar
    # NOTE(review): assumes the scatter call above added a colorbar as the
    # last axes of the figure (depends on config["scatter"]) -- confirm.
    cbar_ax = fig.axes[-1]
    cbar_ax.tick_params(**config["colorbar_tick_params"])

    # ticks
    tax.ticks(**config["axis_ticks"])

    # set axis labels
    tax.left_axis_label("1 - x - y", **config["axis_label"])
    tax.right_axis_label("y", **config["axis_label"])
    tax.bottom_axis_label("x", **config["axis_label"])
    plt.axis('off')

    # if there are some inf index , so report it to user
    if not inf_indexes.empty:
        tooltip = inf_indexes[['x', 'y']].to_string()
        task_bar_data = {
            "color": "orange",
            'message':
            "Figure : {} | Warning: calculation contains inf | {}".format(
                fig.number, title),
            'tooltip': tooltip
        }
        self.task_bar_message.emit(task_bar_data)
    else:
        self.task_bar_message.emit({
            "color": "green",
            "message": "Figure : {} | {}".format(fig.number, title),
            "tooltip": "{} rows were removed: {}".format(
                str(len(inf_nan_indexes)),
                ",".join(str(x) for x in inf_nan_indexes))
        })

    tax.show()
    tax.close()
def quat_slice_heatmap2(tuple_scale,
                        zfunc,
                        slice_val,
                        zfunc_kwargs={},
                        style='triangular',
                        slice_axis='Y',
                        tern_axes=['Co', 'Fe', 'Zr'],
                        labelsize=14,
                        add_labeloffset=0,
                        cmap=plt.cm.viridis,
                        ax=None,
                        figsize=None,
                        vmin=None,
                        vmax=None,
                        Ba=1,
                        multiple=0.1,
                        tick_kwargs={
                            'tick_formats': '%.1f',
                            'offset': 0.02
                        }):
    """
    Ternary heatmap of a slice, with z values computed from formulas

    For every simplex point a formula is built with sliceformula_from_tuple
    and zfunc is evaluated on that formula (instead of on the tuple).

    Args:
        tuple_scale: simplex scale used to generate the points
        zfunc: callable evaluated as zfunc(formula, **zfunc_kwargs)
        slice_val: value of the slice axis; ticks are rescaled to
            1 - slice_val
        zfunc_kwargs: extra keyword arguments for zfunc
        style: heatmap interpolation style
        slice_axis, tern_axes, Ba: passed to sliceformula_from_tuple;
            tern_axes also supplies the corner labels (right, top, left)
        labelsize: font size of the corner labels
        add_labeloffset: extra offset added to every corner label
        cmap: colormap passed to the heatmap
        ax: matplotlib axis to draw on; a new figure is created if None
        figsize: size of the new figure; only used if ax is None
        vmin, vmax: color limits. Default to the min/max of the z values
        multiple, tick_kwargs: passed to rescale_ticks

    Returns:
        (tax, vmin, vmax): the ternary axes and the color limits used
    """
    axis_scale = 1 - slice_val

    tuples = []
    zvals = []
    for tup in simplex_iterator(scale=tuple_scale):
        tuples.append(tup)
        formula = sliceformula_from_tuple(tup,
                                          slice_val=slice_val,
                                          slice_axis=slice_axis,
                                          tern_axes=tern_axes,
                                          Ba=Ba)
        zvals.append(zfunc(formula, **zfunc_kwargs))

    # identity tests: "== None" can be overridden (e.g. by numpy values)
    if vmin is None:
        vmin = min(zvals)
    if vmax is None:
        vmax = max(zvals)

    # heatmap data is keyed on the first two simplex coordinates only
    d = dict(zip([t[0:2] for t in tuples], zvals))

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
        _, tax = ternary.figure(scale=tuple_scale, ax=ax)
    else:
        tax = ternary.TernaryAxesSubplot(scale=tuple_scale, ax=ax)

    tax.heatmap(d,
                style=style,
                colorbar=False,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax)
    rescale_ticks(tax, new_scale=axis_scale, multiple=multiple, **tick_kwargs)
    tax.boundary()
    tax.ax.axis('off')

    tax.right_corner_label(tern_axes[0],
                           fontsize=labelsize,
                           offset=0.08 + add_labeloffset)
    tax.top_corner_label(tern_axes[1],
                         fontsize=labelsize,
                         offset=0.2 + add_labeloffset)
    tax.left_corner_label(tern_axes[2],
                          fontsize=labelsize,
                          offset=0.08 + add_labeloffset)

    tax._redraw_labels()

    return tax, vmin, vmax
def quat_slice_scatter(data,
                       z,
                       slice_start,
                       slice_width=0,
                       slice_axis='Y',
                       tern_axes=['Co', 'Fe', 'Zr'],
                       tern_labels=None,
                       labelsize=14,
                       tick_kwargs={
                           'axis': 'lbr',
                           'linewidth': 1,
                           'tick_formats': '%.1f',
                           'offset': 0.03
                       },
                       nticks=5,
                       add_labeloffset=0,
                       cmap=plt.cm.viridis,
                       ax=None,
                       vmin=None,
                       vmax=None,
                       ptsize=8,
                       figsize=None,
                       scatter_kw={}):
    """
    Ternary scatter plot of the rows of data lying in one slice

    Args:
        data: DataFrame with columns slice_axis, tern_axes and z
        z: name of the column used to color the points
        slice_start: start of the slice along slice_axis
        slice_width: width of the slice. If 0, only rows with
            slice_axis == slice_start are used; otherwise rows with
            slice_start <= slice_axis < slice_start + slice_width
        slice_axis: name of the column that is sliced
        tern_axes: columns holding the ternary coordinates
        tern_labels: corner labels in the order right, top, left.
            Defaults to tern_axes
        labelsize: font size of the corner labels
        tick_kwargs: keyword arguments passed to tax.ticks
        nticks: number of tick intervals along each axis
        add_labeloffset: extra offset added to every corner label
        cmap: colormap passed to scatter
        ax: matplotlib axis to draw on; a new figure is created if None
        vmin, vmax: color limits. Default to the min/max of the non-null
            z values in the slice
        ptsize: marker size
        figsize: size of the new figure; only used if ax is None
        scatter_kw: extra keyword arguments passed to scatter

    Returns:
        (tax, vmin, vmax): the ternary axes and the color limits used.
        vmin/vmax stay None if they were not given and the slice is empty
    """
    if slice_width == 0:
        df = data[data[slice_axis] == slice_start]
    else:
        df = data[(data[slice_axis] >= slice_start)
                  & (data[slice_axis] < slice_start + slice_width)]

    points = df.loc[:, tern_axes].values
    colors = df.loc[:, z].values
    if df[z].isnull().any():
        print('Warning: null values in z column')

    # Only derive limits when there is something to plot: reducing an empty
    # array raises, which used to make empty slices crash before the
    # len(points) guard below was reached. The nan-aware reductions keep
    # null z values from turning the limits into NaN.
    if len(colors) > 0:
        if vmin is None:
            vmin = np.nanmin(colors)
        if vmax is None:
            vmax = np.nanmax(colors)

    scale = 1 - (slice_start + slice_width / 2
                 )  #point coords must sum to scale
    print('Scale: {}'.format(scale))

    #since each comp has different slice_axis value, need to scale points to plot scale
    ptsum = np.sum(points, axis=1)[np.newaxis].T
    scaled_pts = points * scale / ptsum

    # identity test: "== None" can be overridden (e.g. element-wise by arrays)
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
        _, tax = ternary.figure(scale=scale, ax=ax)
    else:
        tax = ternary.TernaryAxesSubplot(scale=scale, ax=ax)

    if len(points) > 0:
        tax.scatter(scaled_pts,
                    s=ptsize,
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    colorbar=False,
                    c=colors,
                    **scatter_kw)

    tax.boundary(linewidth=1.0)
    tax.clear_matplotlib_ticks()

    multiple = scale / nticks
    #manually set ticks and locations - default tax.ticks behavior does not work
    # the 1e-6 keeps the end point (scale) inside the half-open arange
    ticks = list(np.arange(0, scale + 1e-6, multiple))
    locations = ticks
    tax.ticks(multiple=multiple,
              ticks=ticks,
              locations=locations,
              **tick_kwargs)
    tax.gridlines(multiple=multiple, linewidth=0.8)

    if tern_labels is None:
        tern_labels = tern_axes
    tax.right_corner_label(tern_labels[0],
                           fontsize=labelsize,
                           offset=0.08 + add_labeloffset)
    tax.top_corner_label(tern_labels[1],
                         fontsize=labelsize,
                         offset=0.2 + add_labeloffset)
    tax.left_corner_label(tern_labels[2],
                          fontsize=labelsize,
                          offset=0.08 + add_labeloffset)

    tax._redraw_labels()
    ax.axis('off')

    return tax, vmin, vmax
def run_causal_simulations():
    """Run the three models on every problem and save a grid of ternary plots.

    For each problem returned by ``utils.create_active_learning_hyp_space``
    (transmission rate 0.8, background rate 0.0) the predictions of the
    information-gain, self-teaching and positive-test-strategy models are
    collected. Each problem is drawn in its own cell of a 3 x 9 grid and the
    figure is saved to ``figures/causal_learning_simulations.png``.
    """
    transmission_rate = 0.8
    background_rate = 0.0
    problems = utils.create_active_learning_hyp_space(t=transmission_rate,
                                                      b=background_rate)

    # get predictions of all three models
    eig_predictions = []
    teaching_predictions = []
    pts_predictions = []
    for problem in problems:
        learner = GraphActiveLearner(problem)
        learner.update_posterior()
        eig_predictions.append(learner.expected_information_gain().tolist())

        teacher = GraphSelfTeacher(problem)
        teacher.update_learner_posterior()
        teaching_predictions.append(teacher.update_self_teaching_posterior())

        tester = GraphPositiveTestStrategy(problem)
        pts_predictions.append(tester.positive_test_strategy())

    figure, ax = plt.subplots()
    figure.set_size_inches(16, 5)
    ax.set_frame_on(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # dotted guide lines run from the centre point to three edge points
    centre = (0.33, 0.33, 0.33)
    guide_ends = [(0.5, 0.5, 0), (0.5, 0, 0.5), (0, 0.5, 0.5)]

    panels = zip(eig_predictions, pts_predictions, teaching_predictions)
    for number, (eig, pts, teaching) in enumerate(panels, start=1):
        # make ternary plot
        ax = figure.add_subplot(3, 9, number)
        tax = ternary.TernaryAxesSubplot(ax=ax)
        tax.set_title("Problem {}".format(number), fontsize=10)
        tax.boundary(linewidth=2.0)

        series = (
            (eig, 'o', 'red', "Information Gain"),
            (pts, 'd', 'green', "Positive-Test Strategy"),
            (teaching, 's', 'blue', "Self-Teaching"),
        )
        for prediction, shape, colour, name in series:
            tax.scatter([prediction], marker=shape, color=colour,
                        label=name, alpha=0.8, s=40)

        for end in guide_ends:
            tax.line(centre, end, color='black', linestyle=':')

        tax.left_axis_label("x_1", fontsize=20)
        tax.right_axis_label("x_2", fontsize=20)
        tax.bottom_axis_label("x_3", fontsize=20)
        tax.clear_matplotlib_ticks()
        ax.set_frame_on(False)

    # the legend entries of the last axis are fetched but not used further
    handles, labels = ax.get_legend_handles_labels()

    plt.savefig('figures/causal_learning_simulations.png', dpi=600)
def plot(d, t, r, ax=None, title=False, scale=20, fontsize=20, blw=1.0,
         msz=8, small=False):
    """Draw a Sharma-Mittal entropy heatmap with a design's prior and posteriors.

    Parameters
    ----------
    d : design object
        Must provide ``prior`` and ``posterior``. Each posterior entry
        ``q`` is drawn as a line from ``q[0]`` to ``q[1]``; the prior is
        drawn as a single square marker.
        NOTE(review): marker colours and labels exist for two posterior
        entries only -- more entries raise IndexError.
    t, r : float
        Degree and order passed to ``sharma_mittal.sm_entropy``.
    ax : matplotlib axis or None
        Axis to draw on; a new figure is created when omitted.
    title : str or False
        Text placed above the axis; nothing is drawn when falsy.
    scale : int
        Scale of the ternary axes (resolution of the heatmap).
    fontsize : number
        Font size of the axis labels.
    blw : float
        Line width of the triangle boundary.
    msz : number
        Marker size of the prior / posterior markers.
    small : bool
        If true, the order/degree text is omitted and a compact
        three-column legend is used.

    Returns
    -------
    The matplotlib axis that was drawn on.
    """
    def entropy(p):
        return sharma_mittal.sm_entropy(p, t, r)

    # identity test: "== None" can be overridden (e.g. element-wise by arrays)
    if ax is None:
        _, ax = plt.subplots()

    tax = ternary.TernaryAxesSubplot(ax=ax, scale=scale)
    tax.gridlines(color="white", multiple=1, linestyle="-", linewidth=0.1)
    # tax.gridlines(color="white", multiple=2,
    #     linestyle="-", linewidth=0.2)

    # Eleven evenly spaced ticks labelled 0.0 ... 1.0. linspace keeps the
    # number of locations equal to the number of labels for every scale;
    # the former int(scale / 10) step was zero for scale < 10 and produced
    # extra locations when scale was not a multiple of 10.
    numTicks = 10
    tickLocs = np.linspace(0, scale, numTicks + 1)
    ticks = [round(tick, 2) for tick in np.linspace(0, 1.0, numTicks + 1)]
    tax.ticks(ticks=ticks,
              axis='lbr',
              axes_colors={
                  'l': 'grey',
                  'r': 'grey',
                  'b': 'grey'
              },
              offset=0.02,
              linewidth=1,
              locations=list(tickLocs),
              clockwise=True)
    tax.clear_matplotlib_ticks()
    # ax = plt.gca()
    ax.axis('off')

    tax.left_axis_label(r"$P(k_2) \quad \rightarrow$",
                        fontsize=fontsize,
                        offset=0.12)
    tax.right_axis_label(r"$P(k_1) \quad \rightarrow$",
                         fontsize=fontsize,
                         offset=0.12)
    tax.bottom_axis_label(r"$\leftarrow \quad P(k_3)$",
                          fontsize=fontsize,
                          offset=0.10)
    tax._redraw_labels()

    cmap = sns.cubehelix_palette(light=1,
                                 dark=0,
                                 start=2.5,
                                 rot=0,
                                 as_cmap=True,
                                 reverse=True)
    tax.heatmapf(entropy, boundary=True, style="triangular", cmap=cmap)
    tax.boundary(linewidth=blw)

    q_color = ['lime', 'fuchsia']
    for i, q in enumerate(d.posterior):
        tax.line(q[0] * scale,
                 q[1] * scale,
                 linewidth=1.,
                 marker=['o', 'o'][i],
                 color='k',
                 linestyle="-",
                 markersize=msz,
                 markeredgecolor='k',
                 markerfacecolor=q_color[i],
                 markeredgewidth=1.,
                 label=r'$P(K|%s)$' % ['Q_1', 'Q_2'][i])
    # A zero-length "line" from the prior to itself draws just the marker.
    tax.line(d.prior * scale,
             d.prior * scale,
             linewidth=1.,
             marker='s',
             color='k',
             linestyle="-",
             markersize=msz,
             markeredgecolor='k',
             markerfacecolor='yellow',
             markeredgewidth=1.,
             label=r'$P(K)$')

    if title:
        ax.text(0.5,
                1.05,
                title,
                ha='center',
                va='center',
                transform=ax.transAxes)

    #Create Legend
    if not small:
        ax.text(0.00,
                0.95,
                r'Order $\;\;(r) = %.2f$' % r,
                ha='left',
                va='center',
                transform=ax.transAxes)
        ax.text(0.00,
                0.9,
                r'Degree $(t) = %.2f$' % t,
                ha='left',
                va='center',
                transform=ax.transAxes)
    if small:
        tax.legend(bbox_to_anchor=(0., 0.00, 1., -.102),
                   ncol=3,
                   mode="expand",
                   borderaxespad=0.,
                   handletextpad=0,
                   fancybox=True,
                   shadow=False,
                   frameon=True)
    else:
        tax.legend()

    #Remove Colorbar
    # The heatmap call above is expected to have appended a colorbar axes.
    # Remove it from the figure that owns ``ax``, which is not necessarily
    # pyplot's current figure when the caller passed in ``ax``.
    f = ax.get_figure()
    f.delaxes(f.axes[-1])
    plt.draw()

    return ax


# d = design.design(size=(3,2))
# plot(d,t=1,r=1)
# plt.show()
# Five consecutive three-column groups (columns 31..45) of the stiff data.
# NOTE(review): stiff_data / nonstiff_data are loaded outside this excerpt;
# presumably each group holds one trajectory's three components -- confirm.
y0 = stiff_data[:, 31:34]
y1 = stiff_data[:, 34:37]
y2 = stiff_data[:, 37:40]
y3 = stiff_data[:, 40:43]
y4 = stiff_data[:, 43:46]

# The same column groups taken from the non-stiff data.
x0 = nonstiff_data[:, 31:34]
x1 = nonstiff_data[:, 34:37]
x2 = nonstiff_data[:, 37:40]
x3 = nonstiff_data[:, 40:43]
x4 = nonstiff_data[:, 43:46]

# Number of rows in the last non-stiff group.
mark_val = len(x4[:, 0])

# Two side-by-side panels; the first one hosts the stiff-trajectory ternary.
figure, ax = plt.subplots(1, 2, figsize=(12, 6))
tax1 = ternary.TernaryAxesSubplot(ax=ax[0])

# Last row of the last stiff group, and the first row of every stiff group,
# converted with get_coordinates.
eq_point = get_coordinates(y4[-1, :])
st_point = []
st_point.append(get_coordinates(y0[0, :]))
st_point.append(get_coordinates(y1[0, :]))
st_point.append(get_coordinates(y2[0, :]))
st_point.append(get_coordinates(y3[0, :]))
st_point.append(get_coordinates(y4[0, :]))

tax1.boundary()
tax1.gridlines(multiple=0.1, color="black")
tax1.set_title("(a) Stiff Trajectories", fontsize=14)
fontsize = 12
offset = 0.105
tax1.left_axis_label("$H^+$", fontsize=fontsize, offset=offset)
tax1.right_axis_label("$H^*$", fontsize=fontsize, offset=offset)
def format_tern_ax(
        ax=None,
        labels=None,
        labelsides=False,
        title=None,
        ticks=False,
        grid=True,
        removespines=True,
        scale=100,
        fontsize="x-large",
        ticks_kws={
            "axis": "lbr",
            "linewidth": 1,
            "multiple": 10
        },
        bound_kws={
            "linewidth": 1,
            "zorder": 0.6
        },
        grid_kws={
            "multiple": 10,
            "color": "grey",
            "linewidth": 1,
            "zorder": 0.5
        },
):
    """
    Wrapper to return a ternary ax with optional grid, title, labels.

    *ax* is either None (creates a new fig and axes) or a matplotlib axes
    object.

    *labels* is a three-element array of strings in the order left, top,
    right for labelling ternary plot tips, or if *labelsides=True*, ternary
    plot sides.

    *scale* is an integer of subdivisions for ternary axes; *fontsize*
    controls the size of label and title text.

    *ticks_kws* is a dictionary of parameters for tax.ticks(), *bound_kws*
    for tax.boundary(), *grid_kws* for tax.gridlines(). The dictionaries are
    only unpacked, never modified.

    *labels*, *title*, *ticks*, *grid* can be set to False/None to turn off
    each respective item.

    *removespines* removes the regular matplotlib spines from the axes.

    Returns the ternary axes object.
    """
    # Initiate ternary plot
    if ax is not None:
        tax = ternary.TernaryAxesSubplot(ax=ax, scale=scale)
    else:
        _, tax = ternary.figure(scale=scale)

    # Draw Boundary and Gridlines
    tax.boundary(**bound_kws)
    if grid:
        tax.gridlines(**grid_kws)

    # Set Axis labels and Title
    if title:
        tax.set_title(title, fontsize=fontsize)

    if labels:
        if labelsides:
            tax.left_axis_label(labels[0], fontsize=fontsize)
            # NOTE(review): confirm the ternary axes object provides
            # top_axis_label; the other side labels used across this code
            # are left/right/bottom.
            tax.top_axis_label(labels[1], fontsize=fontsize)
            tax.right_axis_label(labels[2], fontsize=fontsize)
        else:
            tax.left_corner_label(labels[0], fontsize=fontsize)
            # offset makes it look more symmetric
            tax.top_corner_label(labels[1], fontsize=fontsize, offset=0.18)
            tax.right_corner_label(labels[2], fontsize=fontsize)

    # Set ticks
    if ticks:
        # previously called on the undefined name ``ternary_ax`` (NameError)
        tax.ticks(**ticks_kws)

    # Remove default Matplotlib Axes
    if removespines:
        tax.get_axes().axis("off")

    return tax