def make_color_grouped_scatter_plot(data_frame, x_name, y_name, color_by, filename, colormap,
                                    x_function = 'dummy', y_function = 'dummy', color_function = 'dummy',
                                    legend = False, colorbar = True):
    """Scatter-plot two columns, colouring points by binned values of a third.

    The values of ``color_by`` are cut into ``len(colormap.colors)`` intervals
    and each interval is drawn with its own ``p.scatter`` call.  The figure is
    saved to ``filename``.

    Parameters:
        data_frame: pandas DataFrame holding the columns; it is copied, so the
            caller's frame is not modified.
        x_name, y_name: column names plotted on the x and y axes.
        color_by: column name whose binned values select the colour group.
        filename: path handed to ``fig.savefig``.
        colormap: colour map object passed to ``Ppl``; must expose ``colors``.
        x_function, y_function, color_function: either a callable or the name
            of one.  Names are resolved with ``eval`` in this function's scope,
            so ``'dummy'`` is the local identity function; any other name
            (e.g. ``'log'``) must be resolvable from the module.
        legend: when true, draw a legend with one entry per interval.
        colorbar: when true, add a vertical colour bar on the right.

    Returns:
        The tuple ``(ax, fig)``.
    """
    ### Originally created for issue_21
    def dummy(a):
        return a

    # NOTE(review): eval() executes arbitrary expressions -- only pass trusted
    # names here.  Each name is resolved once instead of on every use, and an
    # already-callable argument is accepted as-is.
    x_transform = x_function if callable(x_function) else eval(x_function)
    y_transform = y_function if callable(y_function) else eval(y_function)
    color_transform = color_function if callable(color_function) else eval(color_function)

    data_frame = data_frame.copy()
    p = Ppl(colormap, alpha=1)
    fig, ax = plt.subplots(1)

    # Series.min()/max() skip NaN, whereas the builtins may return NaN and
    # produce invalid axis limits.
    x_values = data_frame[x_name]
    y_values = data_frame[y_name]
    ax.set_xlim([x_transform(x_values.min()), x_transform(x_values.max())])
    ax.set_ylim([y_transform(y_values.min()), y_transform(y_values.max())])

    x_label = x_name.capitalize().replace('_', ' ')
    if x_function == 'log':
        x_label += ' (log)'
    y_label = y_name.capitalize().replace('_', ' ')
    if y_function == 'log':
        y_label += ' (log)'
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.xaxis.get_major_formatter().set_powerlimits((0, 1))
    ax.yaxis.get_major_formatter().set_powerlimits((0, 1))

    # Show the whole color range
    n_intervals = len(colormap.colors)
    color_values = data_frame[color_by]
    if color_function == 'log':
        bins = np.logspace(np.log10(color_values.min()), np.log10(color_values.max()),
                           n_intervals + 1, base = 10)
    else:
        bins = np.linspace(color_transform(color_values.min()), color_transform(color_values.max()),
                           n_intervals + 1)

    # include_lowest=True keeps the rows holding the minimum value: with the
    # default left-open first interval they fall outside every bin and would
    # silently disappear from the plot.
    data_frame['groups'] = pandas.cut(color_values, bins=bins, labels = False, include_lowest=True)
    groups = pandas.cut(color_values, bins=bins, include_lowest=True)

    # Newer pandas exposes the interval labels as ``categories`` (through the
    # ``.cat`` accessor when the result is a Series); older releases called
    # them ``levels``.
    accessor = getattr(groups, 'cat', groups)
    interval_labels = getattr(accessor, 'categories', None)
    if interval_labels is None:
        interval_labels = accessor.levels

    for g in range(n_intervals):
        in_group = data_frame[data_frame['groups'] == g]
        x = x_transform(in_group[x_name])
        y = y_transform(in_group[y_name])
        p.scatter(ax, x, y, label=str(interval_labels[g]), s = 5, linewidth=0)

    if legend:
        p.legend(ax, loc=0)

    bounds = bins
    if colorbar:
        cmap = p.get_colormap().mpl_colormap
        norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
        ax2 = fig.add_axes([0.9, 0.1 , 0.03, 0.8])
        cbar = mpl.colorbar.ColorbarBase(ax2, cmap=cmap, spacing='proportional', ticks=bounds,
                                         norm=norm, alpha=1, orientation='vertical')
        # Tick labels are blanked; only the largest (transformed) bound is
        # written just above the bar.
        cbar.ax.set_yticklabels([])
        cbar.ax.text(0, 1.02, '%.3g' % max(map(color_transform, bounds)))

        label = color_by.capitalize().replace('_', ' ')
        if color_function == 'log':
            label += ' (log)'
        cbar.ax.set_ylabel(label, rotation='vertical')

    fig.savefig(filename)
    return ax, fig
def make_pretty_scatter_plot(x, y, xlabel, ylabel, filename, ax=None, fig=None):
    """Draw a semi-transparent scatter plot of ``y`` against ``x`` and save it.

    Parameters:
        x, y: sequences of values; the axis limits are set to their min/max.
        xlabel, ylabel: label sources; each is passed through ``pfn`` before
            being set on the axis.
        filename: path handed to ``fig.savefig``.
        ax: optional axes to draw on.
        fig: optional figure to save.

    When only one of ``ax`` / ``fig`` is supplied the other is derived from
    it; when neither is supplied a fresh figure with one axes is created.
    """
    cmap = brewer2mpl.get_map('Set1', 'qualitative', 9)
    p = Ppl(cmap, alpha=0.3)

    # Previously a lone ``ax`` or lone ``fig`` left the other one as None and
    # crashed on the first attribute access below.
    if ax is None and fig is None:
        fig, ax = plt.subplots(1)
    elif ax is None:
        ax = fig.gca()
    elif fig is None:
        fig = ax.get_figure()

    ax.set_xlabel(pfn(xlabel))
    ax.set_ylabel(pfn(ylabel))
    ax.set_xlim([min(x), max(x)])
    ax.set_ylim([min(y), max(y)])
    ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
    p.scatter(ax, x, y, s=5, linewidth=0)
    fig.savefig(filename)
def make_scatter_plot_for_labelled_data(data_frame, x_name, y_name, labels, filename, colormap,
                                        x_function = 'dummy', y_function = 'dummy', legend = False,
                                        point_size = 5, omit_largest = 0, labels_to_plot = []):
    """Scatter-plot two columns with one colour per label (cluster) and save it.

    Clusters are drawn from largest to smallest with increasing ``zorder`` so
    that smaller clusters end up on top of larger ones.

    Parameters:
        data_frame: pandas DataFrame holding the columns to plot.
        x_name, y_name: column names plotted on the x and y axes; the axis
            labels are ``fl(x_name)`` / ``fl(y_name)``, with `` (log)``
            appended when the matching function is ``'log'``.
        labels: per-row labels; must support ``labels == value`` as a boolean
            mask usable to index both ``labels`` and ``data_frame``.
            Each label is also used as an index into ``colormap.mpl_colors``.
        filename: path handed to ``fig.savefig``.
        colormap: colour map object passed to ``Ppl``; must expose ``mpl_colors``.
        x_function, y_function: either a callable or the name of one.  Names
            are resolved with ``eval`` in this function's scope, so ``'dummy'``
            is the local identity function; any other name (e.g. ``'log'``)
            must be resolvable from the module.
        legend: when true, draw a legend with one ``C<label>: <size>`` entry
            per plotted cluster.
        point_size: marker size passed to ``p.scatter``.
        omit_largest: number of largest clusters to leave out.
        labels_to_plot: labels to draw; empty means all labels.  The default
            list is only rebound, never mutated.

    Raises:
        AssertionError: if ``omit_largest`` is not smaller than the number of
            distinct labels.
    """
    ### Originally created for issue_28
    def dummy(a):
        return a

    # NOTE(review): eval() executes arbitrary expressions -- only pass trusted
    # names here.  Resolved once; an already-callable argument is used as-is.
    x_transform = x_function if callable(x_function) else eval(x_function)
    y_transform = y_function if callable(y_function) else eval(y_function)

    # Build the label set once so sizes and labels are guaranteed to pair up.
    unique_labels = set(labels)
    if not labels_to_plot:
        labels_to_plot = unique_labels
    # Compare against the number of clusters (as the message states), not
    # against the largest label value.
    assert omit_largest < len(unique_labels), "omit_largest must be smaller than number of clusters"
    colors = colormap.mpl_colors

    p = Ppl(colormap, alpha=1)
    fig, ax = plt.subplots(1)

    # Series.min()/max() skip NaN, whereas the builtins may return NaN and
    # produce invalid axis limits.
    x_values = data_frame[x_name]
    y_values = data_frame[y_name]
    ax.set_xlim([x_transform(x_values.min()), x_transform(x_values.max())])
    ax.set_ylim([y_transform(y_values.min()), y_transform(y_values.max())])

    # The old code appended ' (log)' to x_label / y_label, which were never
    # assigned here, so a log axis raised UnboundLocalError.
    x_label = fl(x_name)
    if x_function == 'log':
        x_label = '%s (log)' % x_label
    y_label = fl(y_name)
    if y_function == 'log':
        y_label = '%s (log)' % y_label
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.xaxis.get_major_formatter().set_powerlimits((0, 1))
    ax.yaxis.get_major_formatter().set_powerlimits((0, 1))

    # (size, label) pairs, largest cluster first; ties fall back to the label.
    by_size = sorted(((len(labels[labels == label]), label) for label in unique_labels),
                     reverse=True)
    # A negative omit_largest keeps every cluster, as before.
    kept = by_size[max(omit_largest, 0):]

    # Size and label travel together, so the legend shows each cluster's own
    # size even when the largest clusters are omitted.
    for order_to_plot, (size, group) in enumerate(kept):
        if group in labels_to_plot:
            in_group = data_frame[labels == group]
            x = x_transform(in_group[x_name])
            y = y_transform(in_group[y_name])
            p.scatter(ax, x, y, label='C%s: %s' % (group, size), s=point_size, linewidth=0,
                      zorder=order_to_plot, color=colors[group])

    if legend:
        legend_artist = p.legend(ax, loc=0, fancybox=True, markerscale=5, frameon=False)
        # Keep the legend above every cluster layer.
        legend_artist.set_zorder(100)

    fig.savefig(filename)