def update_graph():
    """Redraw every artist that depends on the current values in ``params``.

    Reads module-level state (``params``, ``time``, the artists and the three
    figures) and mutates the artists in place; returns nothing.
    """
    ydata = microlens(time, params.values())
    line.set_ydata(ydata)
    # Comparison curve: same model call with the trailing three parameters set to 0.
    simple.set_ydata(microlens(time, [params['mag'], params['blend'], params['u0'],
                                      params['t0'], params['tE'], 0, 0, 0]))

    # Colour both projected trajectories by the model magnitude over 16..19.
    colnorm = Normalize(vmin=16, vmax=19)
    colmap = ScalarMappable(norm=colnorm, cmap=plt.get_cmap('Reds_r'))
    point_colors = colmap.to_rgba(ydata)

    xD, yD, xpE, ypE = projected_plan(time, params.values())
    earth_projected_orbit.set_offsets(np.column_stack([xpE, ypE]))
    earth_projected_orbit.set_facecolor(point_colors)
    defl_line.set_offsets(np.column_stack([xD, yD]))
    defl_line.set_facecolor(point_colors)

    vec_earth.set_data(*u_earth(params["ti"], params["delta_u"]))
    deflector_position = u_deflector(params["ti"], params["theta"], params["tE"],
                                     params["u0"], params["t0"])
    vec_defle.set_data(*deflector_position)
    # Line2D.set_data requires sequences (bare scalars are rejected by
    # Matplotlib >= 3.9), so wrap the single current point.
    current_ampli.set_data(np.atleast_1d(params["ti"]),
                           np.atleast_1d(microlens(params["ti"], params.values())))
    einstein_radius.center = deflector_position[0][0], deflector_position[1][0]

    fig.canvas.draw_idle()
    fig2.canvas.draw_idle()
    fig3.canvas.draw_idle()
def plot_edgelength_over_time(el_t, eff, max_num_lines):
    """Plot each leaf's edge length over time as a staircase line.

    Parameters
    ----------
    el_t : dict
        Maps a leaf to its list of ``(time, edge_length)`` pairs; the time of
        the last pair is that leaf's final time.
    eff : mapping
        Indexed by ``L2N[leaf]``; orders the leaves and colours their lines on
        the ``Reds`` map (the legend labels it "Transmission(s)").
    max_num_lines : int or None
        Non-negative; keep only this many leaves with the largest ``eff``.
        ``None`` keeps all of them.

    Every line is extended horizontally to the latest final time among the
    plotted leaves.  Nothing is drawn when no leaves remain.
    """
    leaves = sorted(el_t.keys(), key=lambda leaf: eff[L2N[leaf]])
    if max_num_lines is not None:
        # Explicit start index: leaves[-0:] would keep every leaf, not none.
        leaves = leaves[max(len(leaves) - max_num_lines, 0):]
    if not leaves:
        return

    leaf_effs = [eff[L2N[leaf]] for leaf in leaves]
    min_eff, max_eff = min(leaf_effs), max(leaf_effs)
    max_time = max(el_t[leaf][-1][0] for leaf in leaves)
    norm = Normalize(vmin=min_eff, vmax=max_eff, clip=True)
    color_mapper = ScalarMappable(norm=norm, cmap=Reds)
    # One legend patch per distinct extreme (no duplicate entry when all equal).
    handles = [
        Patch(color=color_mapper.to_rgba(e), label='%d Transmission(s)' % e)
        for e in sorted({min_eff, max_eff})
    ]

    for leaf, leaf_eff in zip(leaves, leaf_effs):
        pairs = el_t[leaf]
        x = [pairs[0][0]]
        y = [pairs[0][1]]  # start (just a point)
        for t, el in pairs[1:]:
            x.append(t)
            y.append(y[-1])  # hold the previous length until t
            x.append(t)
            y.append(el)  # then step to the new length
        x.append(max_time)
        y.append(y[-1])
        plt.plot(x, y, color=color_mapper.to_rgba(leaf_eff))

    plt.legend(handles=handles, bbox_to_anchor=(0.995, 0.995), loc=1, borderaxespad=0.)
    plt.xlabel('Time')
    plt.ylabel('Edge Length')
    plt.title('Edge Length vs. Time')
    plt.tight_layout()
    plt.show()
def update_plot(data):
    '''Update plot for animation.

    Parameters
    ----------
    data : tuple
        ``(projections, name, target_position, centroid_position)`` where
        ``projections`` is ``(projection_xz, projection_yz, projection_xy)``,
        ``target_position`` is the radar target ``(x, y, z)`` and
        ``centroid_position`` is the camera centroid ``(x, y)``.

    Returns
    -------
    tuple
        Every artist modified here (e.g. for blitting).
    '''
    projections, name, target_position, centroid_position = data
    projection_xz, projection_yz, projection_xy = projections
    target_x, target_y, target_z = target_position
    centroid_x, centroid_y = centroid_position

    # Update target position and annotation from radar on plots.
    # Line2D.set_data needs sequences; bare scalars are rejected by Matplotlib >= 3.9.
    target_ant_xz.set_position(xy=(target_x, target_z))
    target_pt_xz.set_data([target_x], [target_z])
    target_ant_yz.set_position(xy=(target_y, target_z))
    target_pt_yz.set_data([target_y], [target_z])
    target_ant_xy.set_position(xy=(target_x, target_y))
    target_pt_xy.set_data([target_x], [target_y])

    # Update name and position annotations of centroid from camera on plots.
    # The centroid is only (x, y), so the radar's target_z is used for height.
    centroid_ant_xz.set_text(s=name)
    centroid_ant_xz.set_position(xy=(centroid_x, target_z))
    centroid_pt_xz.set_data([centroid_x], [target_z])
    centroid_ant_yz.set_text(s=name)
    centroid_ant_yz.set_position(xy=(centroid_y, target_z))
    centroid_pt_yz.set_data([centroid_y], [target_z])
    centroid_ant_xy.set_text(s=name)
    centroid_ant_xy.set_position(xy=(centroid_x, centroid_y))
    centroid_pt_xy.set_data([centroid_x], [centroid_y])

    # Update image colors according to return signal strength on plots.
    # A fresh ScalarMappable per image: with no explicit norm, its first
    # to_rgba call autoscales and then keeps that range, so sharing one would
    # colour the later images with the first image's limits.
    signal_pts_xz.set_color(
        ScalarMappable(cmap='coolwarm').to_rgba(projection_xz.T.flatten()))
    signal_pts_yz.set_color(
        ScalarMappable(cmap='coolwarm').to_rgba(projection_yz.T.flatten()))

    # Scale xy image data relative to target distance.
    signal_pts_xy.set_extent(
        [v * target_z / (zmax - zmin) for v in [xmin, xmax, ymin, ymax]])

    # Rotate xy image if radar horizontal since x and y axis are rotated 90 deg CCW.
    if RADAR_HORIZONTAL:
        projection_xy = np.rot90(projection_xy)
    signal_pts_xy.set_data(ScalarMappable(cmap='coolwarm').to_rgba(projection_xy))

    return (signal_pts_xz, target_ant_xz, target_pt_xz, centroid_ant_xz, centroid_pt_xz,
            signal_pts_yz, target_ant_yz, target_pt_yz, centroid_ant_yz, centroid_pt_yz,
            signal_pts_xy, target_ant_xy, target_pt_xy, centroid_ant_xy, centroid_pt_xy)
def plot_res(self, ax, do_label=True, tag_leafs=False, zlim=False, cmap='jet_r'):
    '''Fill every leaf of the tree on ``ax`` with a colour keyed to its ``dx``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    do_label : bool
        Add a "Resolution:" key to the right of the axes with one coloured
        entry per distinct leaf ``dx``.
    tag_leafs : bool
        If True, each leaf is drawn with its tree key as label.
    zlim : (float, float) or False
        ``(min, max)`` of the logarithmic colour scale.  ``None`` or a bool
        (the default ``False``) uses ``self.dx_min`` / ``self.dx_max``.
    cmap : str
        Matplotlib colormap name.
    '''
    from matplotlib.pyplot import get_cmap
    from matplotlib import patheffects
    from matplotlib.colors import LogNorm
    from matplotlib.cm import ScalarMappable

    # Create a color map using either zlim as given or max/min resolution.
    if zlim is not None and not isinstance(zlim, bool):
        vmin, vmax = zlim
    else:
        vmin, vmax = self.dx_min, self.dx_max
    cNorm = LogNorm(vmin=vmin, vmax=vmax, clip=True)
    cMap = ScalarMappable(cmap=get_cmap(cmap), norm=cNorm)

    dx_vals = set()
    for key in self.tree:
        if self[key].isLeaf:
            dx = self[key].dx
            dx_vals.add(dx)
            # For string keys, key * tag_leafs is the key when True and '' when False.
            self[key].plot_res(ax, fc=cMap.to_rgba(dx), label=key * tag_leafs)

    if do_label:
        ax.annotate('Resolution:', [1.02, 0.99], xycoords='axes fraction',
                    color='k', size='medium')
        for i, key in enumerate(sorted(dx_vals)):
            if key < 1:
                # round() rather than %i truncation, so a reciprocal that lands
                # just below an integer still prints the intended denominator.
                label = '1/%i' % round(1.0 / key)
            else:
                label = '%i' % key
            ax.annotate('%s $R_{E}$' % label, [1.02, 0.87 - i * 0.1],
                        xycoords='axes fraction', color=cMap.to_rgba(key),
                        size='x-large',
                        path_effects=[patheffects.withStroke(linewidth=1, foreground='k')])
def apply(self, data, mask):
    """
    Apply the colorizer to a data/mask set.

    Parameters
    ----------
    data : numpy.ndarray
        Band-first array ``(bands, rows, cols)``.
    mask : numpy.ndarray
        ``(rows, cols)`` mask band; in ``'exact'`` mode 255 marks valid pixels.

    Returns
    -------
    numpy.ndarray
        Also stored on ``self.rgba``.  For a single band, a ``(rows, cols, 4)``
        uint8 RGBA array whose alpha is the mask.  For several bands, the bands
        transposed to ``(rows, cols, bands)`` with the mask appended.

    Raises
    ------
    ValueError
        If a single-band colorizer has an ``interp`` other than ``'linear'``,
        ``'discrete'`` or ``'exact'``.
    """
    self.rgba = None
    if len(self.bands) != 1:
        # Multiband RGB. No colorizing to do, just transpose the data array
        # and add the mask band. Mask may be discarded later when using jpg
        # as output, for example.
        self.rgba = np.dstack((np.transpose(data, (1, 2, 0)), mask))
        return self.rgba

    # Singleband pseudocolor
    data = data[0, :, :]
    if self.interp == 'linear':
        norm = Normalize(self.ranges[0], self.ranges[-1])
    elif self.interp == 'discrete':
        norm = BoundaryNorm(boundaries=self.ranges, ncolors=256)
    elif self.interp == 'exact':
        norm = NoNorm()
        # Reclassify: an unmasked pixel equal to ranges[n] becomes colormap
        # index n + 1; every other pixel is index 0 with a zero mask.
        tmp_data = np.zeros_like(data)
        tmp_mask = np.zeros_like(mask)
        for n, r in enumerate(self.ranges):
            ix = np.logical_and(data == r, mask == 255)
            tmp_data[ix] = n + 1
            tmp_mask[ix] = 255
        data = tmp_data
        mask = tmp_mask
    else:
        # Previously fell through with self.rgba None and failed with a TypeError.
        raise ValueError("Unsupported interp %r; expected 'linear', 'discrete' "
                         "or 'exact'" % (self.interp,))

    sm = ScalarMappable(norm=norm, cmap=self.colormap)
    self.rgba = sm.to_rgba(data, bytes=True)
    self.rgba[:, :, 3] = mask
    return self.rgba
def visualize_prediction_and_explanation(self, x, y, prob, grad):
    """Show a gradient-based explanation of a single prediction.

    The left panel is ``self.explanation_barchart(grad)``.  The right grid
    (6 x 8 panels, features in sorted-name order) holds one histogram per
    feature over ``self.Xv`` with the instance's value marked by a dashed
    line.  Each panel's background is the signed, squared, max-normalised
    gradient on the ``bwr`` map.

    Parameters
    ----------
    x : sequence of float
        Feature values of the instance, in ``self.feature_names`` order.
    y : int
        Index of the true label in ``self.label_names``.
    prob : float
        Predicted probability; rounded to obtain the predicted label index.
    grad : numpy.ndarray
        Per-feature gradient, in ``self.feature_names`` order.
    """
    fig = plt.figure(figsize=(20, 8))
    G = gridspec.GridSpec(6, 10)
    # An all-zero gradient would give 0/0 = NaN weights; fall back to 1.
    mag = np.abs(grad).max() or 1.0
    pred = int(round(prob))
    norm = Normalize(vmin=-1, vmax=1)
    sm = ScalarMappable(norm=norm, cmap=plt.cm.bwr)
    xmin = self.Xv.min(axis=0)
    xmax = self.Xv.max(axis=0)
    xmed = (xmin + xmax) * 0.5

    plt.subplot(G[:, :2])
    self.explanation_barchart(grad)
    plt.title(r'$\hat{y}=' + str(pred) + '$')
    plt.gca().set_yticklabels(['{}: {:.2f}'.format(n, x[j])
                               for j, n in enumerate(self.feature_names)], fontsize=9)

    for i, label in enumerate(sorted(self.feature_names)):
        # NOTE(review): presumably 'age' is stored in days and shown in years.
        divisor = 365.0 if label == 'age' else 1.0
        j = self.feature_names.index(label)
        weight = (grad[j] / mag) ** 2 * np.sign(grad[j])
        # 'axisbg' was removed from Matplotlib; 'facecolor' replaces it.
        plt.subplot(G[i // 8, 2 + i % 8], facecolor=sm.to_rgba(weight))
        plt.hist(self.Xv[:, j] / divisor, bins=25, alpha=0.5, color='blue')
        plt.gca().set_yticklabels([])
        plt.axvline(x[j] / divisor, ls='--', lw=2, color='black')
        plt.tick_params(axis='both', which='major', labelsize=8)
        plt.xticks(np.array([xmin[j], xmed[j], xmax[j]]) / divisor)
        plt.title(label[:16] + ": {:.1f}".format(x[j] / divisor), fontsize=8)

    fig.suptitle('Prediction = {} ({:.1%}), True Outcome = {}'.format(
        self.label_names[pred], prob, self.label_names[y]), fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()
def make_graph(data):
    """Build and draw an undirected networkx graph from the keys of ``data``.

    A key that splits on '_' into exactly five parts adds an edge between its
    2nd and 4th parts.  Any other key is added as a node on its own.  Nodes
    without '_' are parsed as integers and coloured on ``jet`` over 1..241.
    Nodes containing '_' get RGBA ``(0, 1, 0, 0)`` (alpha 0 -- NOTE(review):
    confirm transparency is intended).
    """
    from matplotlib.cm import jet, ScalarMappable
    from matplotlib.colors import Normalize

    graph = nx.Graph()
    mapper = ScalarMappable(norm=Normalize(vmin=1, vmax=241), cmap=jet)
    for key in data:
        parts = key.split('_')
        if len(parts) == 5:
            src, dst = parts[1], parts[3]
            graph.add_node(src)
            graph.add_node(dst)
            graph.add_edge(src, dst)
        else:
            graph.add_node(key)

    layout = nx.spring_layout(graph)
    node_colors = [(0, 1, 0, 0) if '_' in node else mapper.to_rgba(int(node))
                   for node in graph.nodes()]
    nx.draw_networkx_nodes(graph, layout, node_color=node_colors)
    nx.draw_networkx_labels(graph, layout)
    nx.draw_networkx_edges(graph, layout)
    plt.show()
def calctime(dval, maxtime=50.0):
    """Plot KDEs of photometry calculation times, one figure per cadence.

    For each cadence in ``dval.cadences``, stars with
    ``diagnostics.elaptime <= maxtime`` are fetched.  A KDE of all their
    elapsed times is drawn as a black line, with one filled KDE overlaid per
    number of stamp resizes.  Each figure is saved to
    ``dval.outfolder/calctime_cNNNN`` and closed unless ``dval.show`` is set.
    Cadences without matching stars are skipped.
    """
    logger = logging.getLogger('dataval')
    logger.info('Plotting calculation times for photometry...')

    for cadence in dval.cadences:
        rows = dval.search_database(
            select=['diagnostics.stamp_resizes', 'diagnostics.elaptime'],
            search=[f'cadence={cadence:d}', f'diagnostics.elaptime <= {maxtime:f}'])
        if not rows:
            continue

        elapsed = np.array([row['elaptime'] for row in rows], dtype='float64')
        resizes = np.array([row['stamp_resizes'] for row in rows], dtype='int32')
        maxresize = int(np.max(resizes))

        fig, ax = plt.subplots(figsize=plt.figaspect(0.5))
        # Bins of width 1 centred on each integer resize count.
        color_lookup = ScalarMappable(norm=Normalize(vmin=-0.5, vmax=maxresize + 0.5),
                                      cmap=plt.get_cmap('tab10'))

        # KDE of the full dataset.
        kde_all = KDE(elapsed)
        kde_all.fit(kernel='gau', gridsize=1024)

        # One KDE per number of stamp resizes.
        for n_resizes in range(maxresize + 1):
            subset = elapsed[resizes == n_resizes]
            if not len(subset):
                continue
            kde_sub = KDE(subset)
            kde_sub.fit(kernel='gau', gridsize=1024)
            ax.fill_between(kde_sub.support, 0, kde_sub.density,
                            color=color_lookup.to_rgba(n_resizes), alpha=0.5,
                            label=f'{n_resizes:d} resizes')

        ax.plot(kde_all.support, kde_all.density, color='k', lw=2, label='All')
        ax.set_xlim([0, maxtime])
        ax.set_ylim(bottom=0)
        ax.xaxis.set_major_locator(MultipleLocator(5))
        ax.xaxis.set_minor_locator(MultipleLocator(1))
        ax.set_xlabel('Calculation time (sec)')
        ax.legend(loc='upper right')
        fig.savefig(os.path.join(dval.outfolder, f'calctime_c{cadence:04d}'))
        if not dval.show:
            plt.close(fig)
def make_graph(data):
    """Build and draw an undirected networkx graph from the keys of ``data``.

    A key that splits on '_' into exactly five parts adds an edge between its
    2nd and 4th parts.  Any other key is added as a node on its own.  Nodes
    without '_' are parsed as integers and coloured on ``jet`` over 1..241.
    Nodes containing '_' get RGBA ``(0, 1, 0, 0)`` (alpha 0 -- NOTE(review):
    confirm transparency is intended).
    """
    from matplotlib.cm import jet, ScalarMappable
    from matplotlib.colors import Normalize

    graph = nx.Graph()
    mapper = ScalarMappable(norm=Normalize(vmin=1, vmax=241), cmap=jet)
    for key in data:
        parts = key.split('_')
        if len(parts) == 5:
            src, dst = parts[1], parts[3]
            graph.add_node(src)
            graph.add_node(dst)
            graph.add_edge(src, dst)
        else:
            graph.add_node(key)

    layout = nx.spring_layout(graph)
    node_colors = [(0, 1, 0, 0) if '_' in node else mapper.to_rgba(int(node))
                   for node in graph.nodes()]
    nx.draw_networkx_nodes(graph, layout, node_color=node_colors)
    nx.draw_networkx_labels(graph, layout)
    nx.draw_networkx_edges(graph, layout)
    plt.show()
def plot_XCO2(centre):
    """Draw filled XCO2 contours on an orthographic globe and show them.

    ``centre`` is ``(lat, lon)`` of the projection centre.  Data are read from
    ``../earth_data/XCO2_lats.txt``, ``XCO2_lons.txt`` and ``XCO2.txt``.
    Contours use 256 levels over 392-408 ppm on ``RdYlBu_r`` with a horizontal
    colorbar.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.basemap import Basemap
    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable

    lat_axis = np.loadtxt('../earth_data/XCO2_lats.txt')
    lon_axis = np.loadtxt('../earth_data/XCO2_lons.txt')
    xco2 = np.loadtxt('../earth_data/XCO2.txt')
    lon_grid, lat_grid = np.meshgrid(lon_axis, lat_axis)

    lo, hi = 392, 408
    mapper = ScalarMappable(Normalize(vmin=lo, vmax=hi), cmap='RdYlBu_r')
    levels = np.linspace(lo, hi, 256)
    level_colors = [mapper.to_rgba(level) for level in levels]

    globe = Basemap(projection='ortho', lat_0=centre[0], lon_0=centre[1], resolution='l')
    globe.drawcoastlines()
    x, y = globe(lon_grid, lat_grid)
    # Points outside the visible hemisphere come back with huge x values
    # (presumably Basemap's 1e30 sentinel); mask anything above 1e15.
    hidden = np.ma.masked_greater(x, 1e15).mask
    visible_xco2 = np.ma.array(xco2, mask=hidden)

    filled = plt.contourf(x, y, visible_xco2, levels, colors=level_colors, extend='both')
    colorbar = plt.colorbar(filled, orientation='horizontal')
    colorbar.set_label('Xco$_2$ [ppm]')
    plt.show()
def main():
    """Replay NCLT velodyne hits as coloured PointCloud2 messages on ``/map``.

    Hits are read in 1000-row chunks from the HDF5 file and grouped by
    ``unix_time``, which is divided by 1e6 to get seconds.  For each group:

    * the z axis is negated;
    * the points are transformed with the pose from ``get_tx_at_time``
      (groups without a pose are skipped);
    * each point is coloured by intensity on ``viridis``;
    * the result is published with frame ``/map``.
    """
    rospy.init_node('publish_custom_point_cloud')
    # Advertised but nothing is published on it at the moment.
    publisher = rospy.Publisher('/custom_point_cloud', PointCloud2, queue_size=1000)
    map_publisher = rospy.Publisher('/map', PointCloud2, queue_size=1000)

    hdf_path = ("/home/cglwn/Documents/Datasets/Michigan-NCLT/velodyne_data/"
                "2013-01-10-velodyne_hits.hdf")
    for chunk in pd.read_hdf(hdf_path, "velodyne_hits", chunksize=1000):
        for stamp, scan in chunk.groupby("unix_time"):
            # New ScalarMappable per scan: with no explicit norm it autoscales
            # to this scan's own intensity range.
            intensity_cmap = ScalarMappable(cmap="viridis")
            rgb = intensity_cmap.to_rgba(scan["intensity"] / 255.0)[:, :3]
            xyz = scan[["LI_x", "LI_y", "LI_z"]].values
            xyz[:, 2] = -xyz[:, 2]

            pose = get_tx_at_time(stamp)
            if pose is None:
                continue
            x, y, z, yaw, pitch, roll = pose
            xyz = transform(xyz, x, y, z, yaw, pitch, roll)
            header = Header(frame_id='/map', stamp=rospy.Time.from_sec(stamp / 1e6))
            cloud = pc2.create_cloud(header, FIELDS, np.column_stack([xyz, rgb]))
            map_publisher.publish(cloud)
def cont_palette(values: np.ndarray) -> Tuple[np.ndarray, ScalarMappable]:
    """Map continuous ``values`` to hex colour strings.

    The colormap is looked up by the name ``cmap``, which is resolved outside
    this function (closure or module global), and copied.  Values are
    normalised over their NaN-ignoring min..max, and NaNs come out grey.
    Returns the array of hex strings and the ScalarMappable, e.g. for a colorbar.
    """
    palette = copy(plt.get_cmap(cmap))
    palette.set_bad("grey")
    norm = Normalize(vmin=np.nanmin(values), vmax=np.nanmax(values))
    mappable = ScalarMappable(cmap=palette, norm=norm)
    hex_colors = np.array([to_hex(rgba) for rgba in mappable.to_rgba(values)])
    return hex_colors, mappable
def show_elevation(self):
    """
    Draw a contour plot of the elevation data.

    Uses a cylindrical Basemap with 256 filled levels from -8000 m to 8000 m
    on the ``jet`` map, plus a horizontal colorbar ticked every 1000 m.  Due
    to the large size of the dataset, this takes quite a long time.
    """
    from matplotlib.pyplot import colorbar, show
    from mpl_toolkits.basemap import Basemap
    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable

    basemap = Basemap(projection='cyl', resolution='c')
    basemap.drawcoastlines()
    lon_grid, lat_grid = np.meshgrid(self.lons, self.lats)

    lo, hi = -8000., 8000.
    levels = np.linspace(lo, hi, 256)
    mapper = ScalarMappable(Normalize(vmin=lo, vmax=hi), cmap='jet')
    fill_colors = [mapper.to_rgba(level) for level in levels]

    contours = basemap.contourf(lon_grid, lat_grid, self.elev, levels,
                                colors=fill_colors, extend='both')
    cbar = colorbar(contours, orientation='horizontal', aspect=40)
    cbar.set_ticks(np.arange(lo, hi + 1., 1000.))
    cbar.set_label("Surface Elevation [m]")
    show()
def make_plot(self, val, cct, clm, cmap='jet', cmap_norm=None,
              cmap_vmin=None, cmap_vmax=None):
    """Build a PatchCollection of coloured rectangles, one per bin.

    Parameters
    ----------
    val : array-like
        Value of each bin; mapped to that bin's face colour.
    cct :
        Not read here.
    clm : sequence of sequences
        Bin limits ``(x0, x1, y0, y1, ...)``; each becomes
        ``Rectangle((x0, y0), x1 - x0, y1 - y0)``.
    cmap, cmap_norm, cmap_vmin, cmap_vmax :
        Colormap, norm and colour limits of the ScalarMappable.  ``None``
        limits are presumably filled in by autoscaling on ``val``.

    Returns
    -------
    (PatchCollection, ScalarMappable)
        The collection (edges follow the face colour) and the mappable with
        ``val`` set as its array, e.g. for a colorbar.
    """
    smap = ScalarMappable(cmap_norm, cmap)
    smap.set_clim(cmap_vmin, cmap_vmax)
    smap.set_array(val)
    bin_colors = smap.to_rgba(val)

    rectangles = [Rectangle((lim[0], lim[2]), lim[1] - lim[0], lim[3] - lim[2])
                  for lim in clm]
    collection = PatchCollection(rectangles)
    collection.set_edgecolor('face')
    collection.set_facecolor(bin_colors)
    return collection, smap
def create_heatmap(xs, ys, imageSize, blobSize, cmap):
    """Render the points ``(xs[i], ys[i])`` as a colour-mapped heatmap image.

    Each point stamps a Gaussian-blurred disc onto a transparent
    ``imageSize`` x ``imageSize`` RGBA canvas, and overlapping discs are added
    with ``ImageChops.add``.  The red channel is then mapped through ``cmap``
    over its own min..max range, and the accumulated alpha channel is kept.

    Raises
    ------
    ValueError
        If ``xs`` is empty, or if all x or all y coordinates are equal (the
        pixel scale would divide by zero).
    """
    if len(xs) == 0:
        raise ValueError('create_heatmap needs at least one point')
    x_lo, x_hi = min(xs), max(xs)
    y_lo, y_hi = min(ys), max(ys)
    if x_hi == x_lo or y_hi == y_lo:
        raise ValueError('points must span a non-zero range in both x and y')

    blob = Image.new('RGBA', (blobSize * 2, blobSize * 2), '#000000')
    blob.putalpha(0)
    # Per-point intensity shrinks with sqrt(N); PIL fill colours must be ints.
    colour = int(255 / int(math.sqrt(len(xs))))
    draw = ImageDraw.Draw(blob)
    draw.ellipse((blobSize / 2, blobSize / 2, blobSize * 1.5, blobSize * 1.5),
                 fill=(colour, colour, colour))
    blob = blob.filter(ImageFilter.GaussianBlur(radius=blobSize / 2))

    heat = Image.new('RGBA', (imageSize, imageSize), '#000000')
    heat.putalpha(0)
    # Data -> pixel coordinates; the y scale is negative so larger ys are higher up.
    xScale = float(imageSize - 1) / (x_hi - x_lo)
    yScale = float(imageSize - 1) / (y_lo - y_hi)
    for x, y in zip(xs, ys):
        xPos = int((x - x_lo) * xScale)
        yPos = int((y - y_hi) * yScale)
        blobLoc = Image.new('RGBA', (imageSize, imageSize), '#000000')
        blobLoc.putalpha(0)
        blobLoc.paste(blob, (xPos - blobSize, yPos - blobSize), blob)
        heat = ImageChops.add(heat, blobLoc)

    heatArray = pil_to_array(heat)
    red = heatArray[:, :, 0]
    # Normalise over the channel that is actually colour-mapped.  The former
    # min(min(heat.getdata())) picked one pixel tuple lexicographically and
    # could take its alpha value as a limit.
    norm = Normalize(vmin=red.min(), vmax=red.max())
    sm = ScalarMappable(norm, cmap)
    rgba = sm.to_rgba(red, bytes=True)
    rgba[:, :, 3] = heatArray[:, :, 3]
    return Image.fromarray(rgba, 'RGBA')
def color_scatter(self, lats, lngs, values=None, colormap='coolwarm',
                  size=None, marker=False, s=None, **kwargs):
    """Scatter the points one at a time, optionally coloured by ``values``.

    With ``values``, each value is mapped through ``colormap`` over
    ``min(values)..max(values)`` and passed as a ``'#rrggbb'`` string.
    Without ``values``, every point gets ``c=None``.  All remaining arguments
    are forwarded to ``self.scatter`` for each point.
    """
    def rgb2hex(rgb):
        """ Convert RGBA or RGB to #RRGGBB """
        channels = [int(component * 255) for component in list(rgb[0:3])]  # alpha dropped
        return '#%02x%02x%02x' % tuple(channels)

    if values is not None:
        norm = Normalize(vmin=min(values), vmax=max(values))
        scalar_map = ScalarMappable(norm=norm, cmap=plt.get_cmap(colormap))
        colors = [rgb2hex(scalar_map.to_rgba(v)) for v in values]
    else:
        colors = [None for _ in lats]

    for lat, lng, color in zip(lats, lngs, colors):
        self.scatter(lats=[lat], lngs=[lng], c=color, size=size,
                     marker=marker, s=s, **kwargs)
def main(args):
    """Show an image in an OpenCV window through ``exfel_colormap``.

    With no arguments, a synthetic 81x81 Gaussian is displayed.  With one
    argument, the array is loaded from that path with ``numpy.load``.  Values
    are mapped on a log scale (``LogNorm``) and the window stays open until
    Esc is pressed.

    Raises
    ------
    SystemExit
        If more than one argument is given.
    """
    if len(args) == 0:
        x, y = numpy.meshgrid(numpy.linspace(-4, 4, 81), numpy.linspace(-4, 4, 81))
        fake_data = 8 * numpy.exp(-(x**2 + y**2))
        plt.imshow(fake_data, cmap=exfel_colormap)
    elif len(args) == 1:
        fake_data = numpy.load(args[0])
    else:
        # Previously fell through and crashed with a NameError on fake_data.
        raise SystemExit('usage: main([path_to_array])')

    imgFT_cmap_interface = ScalarMappable(norm=matplotlib.colors.LogNorm(),
                                          cmap=exfel_colormap)
    # Convert to RGBA (uint8 based).
    xfel_imgFT = imgFT_cmap_interface.to_rgba(fake_data, bytes=True)
    # Swap RGBA to BGRA (matplotlib and OpenCV use different channel orders).
    cv2_r, cv2_g, cv2_b, cv2_a = cv2.split(xfel_imgFT)
    xfel_imgFT = cv2.merge((cv2_b, cv2_g, cv2_r, cv2_a))

    namedWindow('Image')
    try:
        while True:
            imshow('Image', xfel_imgFT)
            if waitKey(500) == 27:  # esc to quit
                break
    finally:
        # Close the window even if the loop is interrupted.
        destroyAllWindows()
def render(self, model):
    """Return the portrayal layers for one frame of the grid visualisation.

    Every cell gets a filled background rectangle on layer 0.  Its colour
    comes from ``model.grid.fertility[0, x]`` for the cell's column ``x`` on
    ``Greens``, scaled from -0.5 to 1.2 times the maximum fertility.  Each
    agent in a cell is portrayed with ``self.portrayal_method`` and added to
    the layer that portrayal names; falsy portrayals are dropped.

    Returns
    -------
    collections.defaultdict
        Layer number -> list of portrayal dicts.
    """
    colour_map = ScalarMappable(
        norm=Normalize(vmin=-0.5, vmax=np.max(model.grid.fertility) * 1.2),
        cmap='Greens')
    # One hex colour per column.  NOTE(review): sized by self.grid_width but
    # indexed up to model.grid.width below -- these are assumed equal.
    colours = [to_hex(colour_map.to_rgba(model.grid.fertility[0, column]))
               for column in range(self.grid_width)]

    grid_state = defaultdict(list)
    for x in range(model.grid.width):
        for y in range(model.grid.height):
            grid_state[0].append({
                "Shape": "rect",
                "x": x,
                "y": y,
                "w": 1,
                "h": 1,
                "Color": colours[x],
                "Filled": "true",
                "Layer": 0,
            })
            for obj in model.grid.get_cell_list_contents([(x, y)]):
                portrayal = self.portrayal_method(obj)
                if not portrayal:
                    continue
                portrayal["x"] = x
                portrayal["y"] = y
                grid_state[portrayal["Layer"]].append(portrayal)
    return grid_state
def apply_cmap(data, cmap="gray", clim="auto", bytes=False):
    """
    Apply a matplotlib colormap to a 2D or 3D numpy array and return the rgba
    data in float or uint8 format.

    data : 2D or 3D numpy array
    cmap : string denoting a matplotlib colormap
        Colormap used for displaying frames from data. Defaults to 'gray'.
    clim : length-2 list, tuple, or ndarray, or string
        Lower and upper intensity limits to display from data. Defaults to 'auto'.
        If clim='auto', the min and max of data will be used as the clim.
        Before applying the colormap, data will be clipped from clim[0] to clim[1].
    bytes : bool, defaults to False
        If true, return values are uint8 in the range 0-255.
        If false, return values are float in the range 0-1.

    Returns
    -------
    numpy.ndarray
        RGBA array of shape ``data.shape + (4,)``.
    """
    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable
    from numpy import array

    # Only compare with "auto" when clim is a string: for an ndarray clim,
    # `clim == "auto"` is elementwise and `if` on it raises on modern NumPy.
    if isinstance(clim, str) and clim == "auto":
        clim = data.min(), data.max()
    sm = ScalarMappable(Normalize(*clim, clip=True), cmap)
    # Map one slice at a time: ScalarMappable.to_rgba treats any 3-D input as
    # an already-coloured RGB(A) image.
    rgba = array([sm.to_rgba(d, bytes=bytes) for d in data])
    return rgba
def scatter3d(x, y, z, cs, colorsMap='jet'):
    """Scatter ``(x, y, z)`` on the module-level axes ``ax``, coloured by ``cs``.

    ``ax`` is presumably a 3-D axes.  ``cs`` is mapped linearly from its
    min..max onto ``colorsMap``, then the figure is shown.
    """
    colormap = plt.get_cmap(colorsMap)
    mapper = ScalarMappable(norm=Normalize(vmin=min(cs), vmax=max(cs)), cmap=colormap)
    point_colors = mapper.to_rgba(cs)
    ax.scatter(x, y, z, c=point_colors, s=5, linewidth=0)
    mapper.set_array(cs)
    plt.show()
def get_color_map(self, levels):
    """Returns gradient of color from green to red.

    ``levels`` (an array) is divided by its maximum and mapped through
    ``RdYlGn_r``, where low values are green and high values red.  The
    ScalarMappable has no explicit norm, so the colours span the min..max of
    the scaled levels.

    Returns
    -------
    list of str
        One ``'#rrggbb'`` string per level.
    """
    sm = ScalarMappable(cmap='RdYlGn_r')
    normed_levels = levels / np.max(levels)
    # Truncate to ints: '%02x' raises TypeError for floats on Python 3.
    channels = (255 * sm.to_rgba(normed_levels)[:, :3]).astype(int)
    return ['#%02x%02x%02x' % (r, g, b) for r, g, b in channels]
def plot(self, cmap="copper", **plt_args):
    """Create a ternary plot.

    Each composition ``(self.x, self.y, self.z)`` is divided by ``S = x + y + z``
    and placed at ``(0.5 * (2y + z) / S, (sqrt(3) / 2) * z / S)``.  When
    ``self.color`` is set, points are coloured through ``cmap`` over its
    min..max and a colorbar labelled ``self.cbar_label`` is added.  Extra
    keyword arguments go to ``Axes.scatter``; the marker defaults to ``"v"``.
    All spines, ticks and tick labels are hidden.
    """
    self._create_vertices()
    total = self.x + self.y + self.z
    x_cartesian = 0.5 * (2 * self.y + self.z) / total
    y_cartesian = 0.5 * np.sqrt(3) * self.z / total

    plt_args.setdefault("marker", "v")

    colored = self.color is not None
    if colored:
        norm = Normalize(vmin=np.min(self.color), vmax=np.max(self.color))
        smap = ScalarMappable(norm=norm, cmap=cmap)
        plt_args["c"] = [smap.to_rgba(value) for value in self.color]

    self.ax.scatter(x_cartesian, y_cartesian, **plt_args)

    for side in ("right", "top", "left", "bottom"):
        self.ax.spines[side].set_visible(False)
    self.ax.xaxis.set_ticklabels([])
    self.ax.yaxis.set_ticklabels([])
    self.ax.set_xticks([])
    self.ax.set_yticks([])

    if colored:
        smap.set_array([])
        self.fig.colorbar(smap, ax=self.ax, label=self.cbar_label)
def main():
    """Run the skeleton model over a folder of images and save coloured label maps.

    For every image from ``myImageDataset('/data/one punch/how people walk')``:

    * the model's second output is arg-maxed over dim 0 of each sample;
    * the labels are coloured with ``Normalize(0, nSkeleton_MPII - 1)`` on the
      default colormap;
    * the RGB result is saved under ``save_dir`` using the input's name.

    Label maps that are 0 everywhere are not saved.
    """
    save_dir = '/data/one punch/how people walk skeleton'
    model = creatModel().cuda().eval().half()
    state = torch.load(load_model_name)
    model.load_state_dict(state['state_dict'])

    dataLoader = data.DataLoader(
        myImageDataset('/data/one punch/how people walk'),
        batch_size=1,
        num_workers=1)
    # The norm has fixed limits, so one mapper can be reused for every image.
    cm = ScalarMappable(Normalize(0, nSkeleton_MPII - 1))

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for x_, name in dataLoader:
            result = model(x_.cuda().half())
            skeleton = result[1]
            for i in range(skeleton.shape[0]):
                labels = torch.argmax(skeleton[i], dim=0)
                if labels.max() == 0:
                    continue
                rgb = cm.to_rgba(labels.cpu(), bytes=True)[:, :, :3]
                Image.fromarray(rgb).save(path.join(save_dir, name[i]))
def digit_to_rgb(X, scaling=3, shape=(), cmap='binary'):
    '''
    Turn an intensity array into an RGB image through a colormap.

    Parameters
    ----------
    X : numpy.ndarray
        intensity matrix as array of shape [M x N]
    scaling : int
        optional. positive integer value > 0; enlargement factor
    shape : tuple or list of ints, length = 2
        optional. if not given, X is reshaped to be square.
    cmap : str
        name of color map of choice. default is 'binary'

    Returns
    -------
    image : numpy.ndarray
        three-dimensional array of shape [scaling*H x scaling*W x 3],
        where H*W == M*N.  Colours span the image's own min..max (the
        ScalarMappable has no explicit norm).
    '''
    mapper = ScalarMappable(cmap=cmap)
    enlarged = enlarge_image(vec2im(X, shape), scaling)
    rgba = mapper.to_rgba(enlarged)
    return rgba[:, :, 0:3]
def free_energy_vs_comp():
    """Plot detrended beta*G curves against MgSi fraction at several temperatures.

    For each temperature ``T`` in 500, 600, 700, 800 K:

    * the bias in ``data/pseudo_binary_free/adaptive_bias{T}K_-650mev.h5`` is
      negated to give beta*G;
    * the ``x`` grid is divided by 2000;
    * the least-squares line through the curve is subtracted;
    * the curve is shifted so its first point is 0.

    Lines are coloured on ``copper`` over 400-800 K, with a matching
    horizontal colorbar.
    """
    from matplotlib import pyplot as plt
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize

    temps = [500, 600, 700, 800]
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    cNorm = Normalize(vmin=400, vmax=800)
    scalarMap = ScalarMappable(norm=cNorm, cmap="copper")

    for T in temps:
        fname = "data/pseudo_binary_free/adaptive_bias{}K_-650mev.h5".format(T)
        with h5.File(fname, 'r') as hfile:
            x = np.array(hfile["x"]) / 2000.0
            betaG = -np.array(hfile["bias"])
        # Remove the linear trend, then anchor the curve at its first point.
        res = linregress(x, betaG)
        slope, intercept = res[0], res[1]
        betaG -= slope * x + intercept
        betaG -= betaG[0]
        ax.plot(x, betaG, drawstyle="steps", color=scalarMap.to_rgba(T))

    ax.set_xlabel("Fraction MgSi")
    # Unescaped '$' so mathtext renders; "\$...\$" printed the dollars literally.
    ax.set_ylabel(r"$\beta \Delta G$")
    scalarMap.set_array([400.0, 800])
    # ax= is required for a free-standing ScalarMappable on recent Matplotlib.
    cbar = fig.colorbar(scalarMap, ax=ax, orientation="horizontal", fraction=0.07,
                        anchor=(1.0, 0.0))
    cbar.set_label("Temperature (K)")
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    plt.show()
def reflection():
    """Plot beta*G against normalised diffraction intensity at 600, 700 and 800 K.

    Each curve is the negated bias from ``data/diffraction/layered_bias{T}K.h5``,
    shifted so its minimum is 0 and coloured on ``copper`` over 600-800 K.
    The fitted curve ``load_fit_data()["Fitted"]`` is overlaid as a black
    dashed line.
    """
    from matplotlib import pyplot as plt
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize

    temps = [600, 700, 800]
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    cNorm = Normalize(vmin=600, vmax=800)
    scalarMap = ScalarMappable(norm=cNorm, cmap="copper")

    for T in temps:
        fname = "data/diffraction/layered_bias{}K.h5".format(T)
        with h5.File(fname, 'r') as hfile:
            x = np.array(hfile["x"])
            betaG = -np.array(hfile["bias"])
        betaG -= np.min(betaG)
        ax.plot(x, betaG, drawstyle="steps", color=scalarMap.to_rgba(T), lw=2)

    ax.set_xlim([0.25, 0.5])
    ax.set_xlabel("Normalised diffraction intensity")
    # Unescaped '$' so mathtext renders; "\$...\$" printed the dollars literally.
    ax.set_ylabel(r"$\beta \Delta G$")
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    data = load_fit_data()
    ax.plot(data["Fitted"]["x"], data["Fitted"]["y"], ls="--", color="black")
    plt.show()
def _plotimage(imagedata, out, i):
    """Alpha-blend layer ``i`` of ``imagedata`` into the RGB buffer ``out`` in place.

    Only pixels where ``imagedata.imgs[i] > 0`` change, becoming
    ``(1 - alpha) * out + alpha * colour``.  The colour comes from
    ``imagedata.cmap`` over ``imagedata.vmin..imagedata.vmax``.
    """
    layer = imagedata.imgs[i]
    norm = Normalize(vmin=imagedata.vmin, vmax=imagedata.vmax)
    mapper = ScalarMappable(norm=norm, cmap=imagedata.cmap)
    visible = layer > 0
    alpha = imagedata.alpha
    blended_base = out[visible, :] * (1 - alpha)
    layer_rgb = mapper.to_rgba(layer, bytes=False)[visible, :3]
    out[visible, :] = blended_base + alpha * layer_rgb
def generate_colors_palette(cmap="viridis", n_colors=10, alpha=1.0):
    """Generate colors from matplotlib colormap; pass list to use exact colors

    Returns a list of ``[r, g, b, a]`` lists.  For a colormap, ``n_colors``
    colours are taken at 0..n_colors-1, autoscaled across the map.  For a
    list, each entry is converted with ``to_rgba`` and given ``alpha``.
    """
    if not isinstance(cmap, list):
        mapper = ScalarMappable(cmap=cmap)
        return mapper.to_rgba(range(n_colors), alpha=alpha).tolist()
    return [list(to_rgba(color, alpha=alpha)) for color in cmap]
def paintChain(self, chainID, **kw):
    """Colour each residue of chain ``chainID`` by its stored value.

    Keyword arguments
    -----------------
    colormapname : str
        Colormap name (default ``self.colormapname``).
    norm : str
        Key into ``self.norms`` (default ``'global'``).

    For each ``resnum -> value`` in ``self.chains[chainID]``:

    * a colour named ``chain{chainID}res{resnum}`` is registered with
      ``cmd.set_color`` from the RGB part of the mapped colour;
    * that colour is applied to ``"chain {chainID} and resi {resnum}"``.

    ``cmd`` is presumably PyMOL's.
    """
    colormapname = kw.get('colormapname', self.colormapname)
    Norm = self.norms[kw.get('norm', 'global')]
    cmap = ScalarMappable(Norm, get_cmap(colormapname))
    # .items() works on Python 2 and 3; .iteritems() is Python-2 only.
    for resnum, value in self.chains[chainID].items():
        colorname = "chain{}res{}".format(chainID, resnum)
        cmd.set_color(colorname, cmap.to_rgba(value)[:-1])
        cmd.color(colorname, "chain {} and resi {}".format(chainID, resnum))
def show_cam_on_image(img, mask):
    """Overlay ``mask`` as a jet heatmap on ``img`` and display the result.

    ``mask`` is mapped through ``cm.jet`` (autoscaled to its own range) with
    alpha dropped, then added to ``img``.  The sum is divided by its maximum,
    and NaN/inf are replaced via ``np.nan_to_num`` before ``imshow``.
    """
    jet_rgb = ScalarMappable(cmap=cm.jet).to_rgba(mask)[:, :, :-1]
    overlay = np.float32(jet_rgb) + np.float32(img)
    overlay /= np.max(overlay)
    imshow(np.nan_to_num(overlay))
def getcolor(self, V, F):
    """Colour faces by the squared distance of their first vertex from the origin.

    For each face ``f`` in ``F``, the first row of ``V[f]`` is taken and its
    squared Euclidean norm (first three components) computed.  Those values
    are mapped onto ``jet``, autoscaled.

    Returns the ScalarMappable (its array is set, e.g. for a colorbar) and
    the RGBA colours, one per face.
    """
    sq_dist = numpy.empty(len(F))
    for idx, face in enumerate(F):
        first = V[face][0]
        px, py, pz = first[0], first[1], first[2]
        sq_dist[idx] = px * px + py * py + pz * pz
    mapper = ScalarMappable(cmap='jet')
    mapper.set_array(sq_dist)
    return mapper, mapper.to_rgba(sq_dist)
def _check_cmap_rgb_vals(vals, cmap, vmin=0, vmax=1):
    """Helper function to check RGB values of color images

    ``vals`` yields ``(value, expected_rgb)`` pairs.  Each value is mapped
    through ``cmap`` over ``vmin..vmax``, and its RGB part must match
    ``expected_rgb`` within an absolute tolerance of 1e-5.
    """
    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable

    mapper = ScalarMappable(norm=Normalize(vmin, vmax), cmap=cmap)
    for value, expected in vals:
        actual = mapper.to_rgba(value)[:-1]
        assert_allclose(actual, expected, atol=1e-5)
def plot_col(col, label='', cmap='Blues', **kwargs):
    """Plot ``traj(i, col)`` against ``hours`` for ``i`` in (0, 7).

    Each line is labelled ``'{label} {concs[i]}'`` and coloured by ``i`` on
    ``cmap`` scaled over -7..7.  Extra keyword arguments go to ``plt.plot``.
    """
    color_map = ScalarMappable(Normalize(-7, 7), cmap)
    for idx in (0, 7):
        trajectory = traj(idx, col)
        plt.plot(hours, trajectory, label='{} {}'.format(label, concs[idx]),
                 color=color_map.to_rgba(idx), **kwargs)
def show_cam_on_image(img, mask):
    """Overlay ``mask`` as a jet heatmap on ``img``, display it and return both.

    Returns ``(heatmap, cam)``.  ``heatmap`` is the float32 RGB jet colouring
    of ``mask`` (autoscaled).  ``cam`` is ``heatmap + img`` divided by its
    maximum.
    """
    heatmap = np.float32(ScalarMappable(cmap=cm.jet).to_rgba(mask)[:, :, :-1])
    cam = heatmap + np.float32(img)
    cam /= np.max(cam)
    from skimage.io import imshow
    imshow(cam)
    return heatmap, cam
def plot_events(mapobj, axisobj, catalog, label=None, color='depth', pretty=False,
                colormap=None, llat=-90, ulat=90, llon=-180, ulon=180,
                figsize=(16, 24), par_range=(-90., 120., 30.),
                mer_range=(0., 360., 60.), showHour=False, M_above=0.0,
                location='World', min_size=1, max_size=8, **kwargs):
    '''Simplified version of plot_event

    Scatter the events of ``catalog`` on ``mapobj``.  Events are filtered by
    ``get_event_info`` with ``M_above`` and the lat/lon box.

    * Marker area grows with magnitude on a fixed 0-10 scale.
    * Colour encodes the ``color`` field through ``colormap``.  The default
      is the current default map for ``'date'`` and ``RdYlGn_r`` otherwise.
    * With at most one event, a single large red marker is used.

    ``axisobj``, ``pretty``, ``figsize``, ``par_range``, ``mer_range``,
    ``showHour``, ``location`` and ``kwargs`` are accepted but not used.

    Returns whatever ``mapobj.scatter`` returns.
    '''
    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable

    lats, lons, mags, times, labels, colors = get_event_info(
        catalog, M_above, llat, ulat, llon, ulon, color, label)
    min_color = min(colors)
    max_color = max(colors)

    if colormap is None:
        # Dates use the default map; depth uses green->yellow->red.
        colormap = plt.get_cmap() if color == "date" else plt.get_cmap("RdYlGn_r")
    scal_map = ScalarMappable(norm=Normalize(min_color, max_color), cmap=colormap)
    scal_map.set_array(np.linspace(0, 1, 1))

    x, y = mapobj(lons, lats)

    if len(mags) <= 1:
        magnitude_size = 15.0 ** 2
        colors_plot = "red"
    else:
        min_mag = 0
        max_mag = 10
        fractions = [(mag - min_mag) / (max_mag - min_mag) for mag in mags]
        magnitude_size = [(frac * (max_size - min_size)) ** 2 for frac in fractions]
        colors_plot = [scal_map.to_rgba(value) for value in colors]

    return mapobj.scatter(x, y, marker='o', s=magnitude_size, c=colors_plot, zorder=10)
def plotTciData(self, ax):
    """Plot the ten TCI channels ``nl_01``..``nl_10`` from ``self.elecTree`` on ``ax``.

    Each channel is plotted against its ``dim_of()`` axis.  Line ``i`` is
    coloured by ``-rad[i - 1]`` on the default colormap (autoscaled over all
    radii).  The y axis is labelled ``'ne'``.
    """
    radii = self.elecTree.getNode(r'\electrons::top.tci.results:rad').data()
    line_colors = ScalarMappable().to_rgba(-radii)
    for channel in range(1, 11):
        node = self.elecTree.getNode(r'\electrons::top.tci.results:nl_%02d' % channel)
        ax.plot(node.dim_of().data(), node.data(), c=line_colors[channel - 1])
    ax.set_ylabel('ne')
def cm2colors(N, cmap='autumn'):
    """Take N evenly spaced colors out of cmap and return them as a list.

    Entry ``i`` is the RGBA tuple for ``i`` in ``0 .. N-1``, with the colormap
    spread linearly over that range.
    """
    cNorm = Normalize(vmin=0, vmax=N - 1)
    scalarMap = ScalarMappable(norm=cNorm, cmap=cmap)
    # range() rather than xrange(), which does not exist on Python 3.
    return [scalarMap.to_rgba(i) for i in range(N)]
def _update_scatter(self):
    """Recolour the scatter points after a change in *tidx* or *xidx*.

    Only the face colours change.  Index 2 of the last axis of
    ``self.data_sets[1][self.tidx]`` is mapped through the colorbar's norm
    and colormap.  Does nothing when fewer than two data sets are loaded.

    CALL THIS AFTER *_update_image*
    """
    if len(self.data_sets) >= 2:
        mapper = ScalarMappable(norm=self.cbar.norm, cmap=self.cbar.get_cmap())
        face_colors = mapper.to_rgba(self.data_sets[1][self.tidx, :, 2])
        self.scatter.set_facecolors(face_colors)
class ColorMapper(object):
    """Map scalar values to ``'#rrggbb'`` strings through a named colormap.

    Values are normalised linearly from ``bottom_val`` to ``top_val``.
    """

    def __init__(self, bottom_val, top_val, color_palette_name):
        self.scalar_map = ScalarMappable(
            norm=mpl_Normalize(vmin=bottom_val, vmax=top_val),
            cmap=get_cmap(color_palette_name))

    @staticmethod
    def rgb_to_hex(rgb):
        """Format an ``(r, g, b)`` triple of 0-1 floats as ``'#rrggbb'`` (truncating)."""
        return '#%02x%02x%02x' % tuple(int(channel * 255) for channel in rgb)

    def color_from_val(self, val):
        """Return the hex colour for ``val``; alpha is discarded."""
        rgba = self.scalar_map.to_rgba(val)
        return self.rgb_to_hex(rgba[:3])
def plot_res(self, ax, do_label=True, tag_leafs=False, zlim=False, cmap='jet_r'):
    '''Fill every leaf of the tree on ``ax`` with a colour keyed to its ``dx``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    do_label : bool
        Add a "Resolution:" key to the right of the axes with one coloured
        entry per distinct leaf ``dx``.
    tag_leafs : bool
        If True, each leaf is drawn with its tree key as label.
    zlim : (float, float) or False
        ``(min, max)`` of the logarithmic colour scale.  ``None`` or a bool
        (the default ``False``) uses ``self.dx_min`` / ``self.dx_max``.
    cmap : str
        Matplotlib colormap name.
    '''
    from matplotlib.pyplot import get_cmap
    from matplotlib import patheffects
    from matplotlib.colors import LogNorm
    from matplotlib.cm import ScalarMappable

    # Create a color map using either zlim as given or max/min resolution.
    if zlim is not None and not isinstance(zlim, bool):
        vmin, vmax = zlim
    else:
        vmin, vmax = self.dx_min, self.dx_max
    cNorm = LogNorm(vmin=vmin, vmax=vmax, clip=True)
    cMap = ScalarMappable(cmap=get_cmap(cmap), norm=cNorm)

    dx_vals = set()
    for key in self.tree:
        if self[key].isLeaf:
            dx = self[key].dx
            dx_vals.add(dx)
            # For string keys, key * tag_leafs is the key when True and '' when False.
            self[key].plot_res(ax, fc=cMap.to_rgba(dx), label=key * tag_leafs)

    if do_label:
        ax.annotate('Resolution:', [1.02, 0.99], xycoords='axes fraction',
                    color='k', size='medium')
        for i, key in enumerate(sorted(dx_vals)):
            if key < 1:
                # round() rather than %i truncation, so a reciprocal that lands
                # just below an integer still prints the intended denominator.
                label = '1/%i' % round(1.0 / key)
            else:
                label = '%i' % key
            ax.annotate('%s $R_{E}$' % label, [1.02, 0.87 - i * 0.1],
                        xycoords='axes fraction', color=cMap.to_rgba(key),
                        size='x-large',
                        path_effects=[patheffects.withStroke(linewidth=1, foreground='k')])
def generalized_bar_chart(code_matrix, trans_names, code_names, show_it=True,
                          show_trans_names=False, color_map="jet", legend_labels=None,
                          title=None, horizontal_grid=True):
    """Draw a grouped bar chart: one group per row, one bar per column.

    Parameters
    ----------
    code_matrix : numpy.ndarray
        2-D array; ``code_matrix[j, i]`` is the height of bar ``i`` in group
        ``j``.  Only the first ``len(trans_names)`` rows are used.
    trans_names : sequence of str
        One name per group; used as tick labels when ``show_trans_names``.
    code_names :
        Ignored (bars are indexed by ``range(code_matrix.shape[1])``); kept for
        interface compatibility.
    show_it : bool
        Call ``fig.show()`` before returning.
    show_trans_names : bool
        Label groups with ``trans_names`` (rotated).  Otherwise groups are
        labelled 1..N with a y grid and dashed separators.
    color_map : str
        Matplotlib colormap name, or ``"AA_traditional_"`` for a fixed, cycled
        list of ten named colours.
    legend_labels : sequence of str, optional
        If given, a legend is placed to the right of the axes.
    title : str, optional
        Axes title.
    horizontal_grid : bool
        Whether to draw horizontal major grid lines (only when
        ``show_trans_names`` is False).

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = pylab.figure(facecolor="white", figsize=(12, 4))
    fig.subplots_adjust(left=.05, bottom=.15, right=.98, top=.95)
    n_groups = len(trans_names)
    n_codes = code_matrix.shape[1]
    # Bar heights for bar i across all groups.
    ldata = {i: [code_matrix[j, i] for j in range(n_groups)] for i in range(n_codes)}
    ind = np.arange(n_groups)
    width = 1.0 / (n_codes + 1)
    traditional_colors = ['red', 'orange', 'yellow', 'green', 'blue', 'purple',
                          'black', 'grey', 'cyan', 'coral']

    ax = fig.add_subplot(111)
    if title is not None:
        ax.set_title(title, fontsize=10)
    if color_map == "AA_traditional_":
        lcolors = traditional_colors
    else:
        cNorm = mpl_Normalize(vmin=1, vmax=n_codes)
        scalar_map = ScalarMappable(norm=cNorm, cmap=get_cmap(color_map))
        lcolors = [scalar_map.to_rgba(idx + 1) for idx in range(n_codes)]

    the_bars = []
    for c in range(n_codes):
        new_bars = ax.bar(ind + (c + .5) * width, ldata[c], width,
                          color=lcolors[c % len(lcolors)])
        the_bars.append(new_bars[0])

    if show_trans_names:
        ax.set_xticks(ind + .5)
        ax.set_xticklabels(trans_names, size="x-small", rotation=-45)
    else:
        # Visibility passed positionally: the 'b=' keyword was removed in Matplotlib 3.7.
        ax.grid(horizontal_grid, which="major", axis='y')
        ax.set_xticks(ind + .5)
        ax.set_xticklabels(ind + 1, size="x-small")
        for i in ind[1:]:
            ax.axvline(x=i, linestyle="--", linewidth=.25, color='black')

    if legend_labels is not None:
        fontP = FontProperties()
        fontP.set_size('small')
        box = ax.get_position()
        # Shrink the axes to leave room for the legend on the right.
        ax.set_position([box.x0, box.y0, box.width * 0.825, box.height])
        ax.legend(the_bars, legend_labels, loc='center left', bbox_to_anchor=(1, 0.5),
                  prop=fontP)

    ax.set_xlim(right=len(trans_names))
    if show_it:
        fig.show()
    return fig
def plot(countries, values, label='', clim=None, verbose=False):
    """
    Usage: worldmap.plot(countries, values [, label] [, clim])

    Draw a PlateCarree world map with every Natural Earth (110m) country
    filled by its value, plus a horizontal colorbar.

    countries : sequence of country codes (lists and arrays both accepted)
    values    : sequence of numbers aligned with *countries*
    label     : colorbar label
    clim      : optional (vmin, vmax); by default mean +/- 2 std of the
                finite values is used
    verbose   : print every map country that has no entry in *countries*

    Countries listed with a non-finite value are grey; countries not listed
    are white. If a code appears more than once, its first value is used.
    """
    # Boolean masking below requires arrays; this also accepts plain lists.
    countries = np.asarray(countries)
    values = np.asarray(values, dtype=float)

    countries_shp = shpreader.natural_earth(resolution='110m',
                                            category='cultural',
                                            name='admin_0_countries')

    ## Create a plot
    fig = plt.figure()
    ax = plt.axes(projection=ccrs.PlateCarree())

    ## Create a colormap
    cmap = plt.get_cmap('RdYlGn_r')
    if clim:
        vmin, vmax = clim[0], clim[1]
    else:
        finite = values[np.isfinite(values)]
        mean = finite.mean()
        std = finite.std()
        vmin = mean - 2 * std
        vmax = mean + 2 * std
    norm = Normalize(vmin=vmin, vmax=vmax)
    smap = ScalarMappable(norm=norm, cmap=cmap)
    ax2 = fig.add_axes([0.3, 0.18, 0.4, 0.03])
    cbar = ColorbarBase(ax2, cmap=cmap, norm=norm, orientation='horizontal')
    cbar.set_label(label)

    # Natural Earth codes that differ from the codes used in *countries*.
    code_aliases = {
        'SDS': 'SSD',  # South Sudan
        'ROU': 'ROM',  # Romania
        'COD': 'ZAR',  # Dem. Rep. Congo
        'KOS': 'KSV',  # Kosovo
    }

    ## Add countries to the map
    for country in shpreader.Reader(countries_shp).records():
        countrycode = country.attributes['adm0_a3']
        countryname = country.attributes['name_long']
        countrycode = code_aliases.get(countrycode, countrycode)

        matches = countries == countrycode
        if np.any(matches):
            # First match only, so the facecolor is a single RGBA tuple.
            val = values[matches][0]
            color = smap.to_rgba(val) if np.isfinite(val) else 'grey'
        else:
            color = 'w'
            if verbose:
                print("No data available for " + countrycode + ": " + countryname)
        ax.add_geometries(country.geometry, ccrs.PlateCarree(),
                          facecolor=color, label=countryname)
    plt.show()
def generate_colors(desired_palette, num_desired_colors):
    """
    Generate a list of colour strings interpolated from desired_palette.

    desired_palette is from palettable (its ``mpl_colormap`` is sampled).
    Conceptually, this takes a list of colors, and lets you generate any
    length of colors from that array. Colours are sampled at
    0, 1/n, ..., (n-1)/n of the colormap and converted with ``rgb_to_hex``
    from 0-255 integer RGB components (alpha dropped).
    """
    cmap = desired_palette.mpl_colormap
    mappable = ScalarMappable(norm=Normalize(vmin=0, vmax=1), cmap=cmap)
    cols = []
    for i in range(num_desired_colors):
        # float() keeps this a true division under Python 2 as well; integer
        # division would map every sample to 0 and yield a single colour.
        r, g, b, _alpha = mappable.to_rgba(i / float(num_desired_colors))
        # A concrete tuple (not a lazy map object) is safe for any rgb_to_hex.
        cols.append(rgb_to_hex(tuple(int(x * 255) for x in (r, g, b))))
    return cols
def _init_scatter(self):
    """
    Plot scatter points at the base of each vector showing the vertical
    deformation (third component) of the second data set at time index
    ``self.tidx``. If there is only one data set, ``self.scatter`` is set
    to None and nothing is drawn.

    CALL THIS AFTER *_init_image* (it reuses ``self.cbar``'s norm and cmap).
    """
    if len(self.data_sets) < 2:
        self.scatter = None
        return

    # Use the ``cmap`` attribute rather than ``get_cmap()``: recent matplotlib
    # Colorbar objects no longer provide get_cmap(), while ``.cmap`` exists
    # on both Colorbar and ScalarMappable.
    sm = ScalarMappable(norm=self.cbar.norm, cmap=self.cbar.cmap)
    # use scatter points to show z for second data set
    colors = sm.to_rgba(self.data_sets[1][self.tidx, :, 2])
    self.scatter = self.map_ax.scatter(self.x[:, 0], self.x[:, 1],
                                       c=colors, s=self.scatter_size,
                                       zorder=2, edgecolor=self.colors[1])
def __plot_variance(self):
    """
    Shade the band between the minimum and maximum traces with one polygon
    per x step, coloured by the spread (max - min) at that x using
    ``self.colourMap``. The collection is added to ``self.axes``.

    Returns (None, None).
    """
    pointsMin = self.__calc_min()
    pointsMax = self.__calc_max()

    polys = []
    variance = []
    lastX = lastYMin = lastYMax = None
    # Sorted so each strip joins adjacent x values regardless of dict order.
    for x in sorted(pointsMin):
        yMin = pointsMin[x]
        yMax = pointsMax[x]
        if lastX is None:
            # First point: the strip is degenerate (zero width) but is kept so
            # that polygons and variance values stay paired one-to-one.
            lastX, lastYMin, lastYMax = x, yMin, yMax
        polys.append([[x, yMin], [x, yMax],
                      [lastX, lastYMax], [lastX, lastYMin],
                      [x, yMin]])
        lastX, lastYMin, lastYMax = x, yMin, yMax
        variance.append(yMax - yMin)

    if not variance:
        return None, None

    # Use the real extremes; fixed seeds (1000 / 0) gave a wrong colour scale
    # whenever every spread was above 1000 or below 0.
    norm = Normalize(vmin=min(variance), vmax=max(variance))
    sm = ScalarMappable(norm, self.colourMap)
    colours = sm.to_rgba(variance)

    pc = PolyCollection(polys)
    pc.set_gid('plot')
    pc.set_norm(norm)
    pc.set_color(colours)
    self.axes.add_collection(pc)

    return None, None
def pyplot_bar(y, cmap='Blues'):
    """
    Make a good looking pyplot bar plot. Use a colormap to color the bars.

    y: height of bars
    cmap: colormap (name or Colormap), defaults to 'Blues'
    """
    import numpy
    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable

    vmax = numpy.max(y)
    # Start the colour scale below min(y): the shortest bar lands one third
    # of the way up the colormap instead of at its (often white) bottom.
    vmin = (numpy.min(y) * 3. - vmax) / 2.
    # Honour the *cmap* argument (it was previously hard-coded to 'Blues').
    colormap = ScalarMappable(norm=Normalize(vmin, vmax), cmap=cmap)

    plt.bar(numpy.arange(len(y)), y, color=colormap.to_rgba(y),
            align='edge', width=1.0)
def make_plot(self, val, cct, clm, cmap='jet', cmap_norm=None, cmap_vmin=None, cmap_vmax=None):
    """
    Build a PatchCollection of rectangles coloured by *val*.

    Each entry ``b`` of *clm* describes one rectangle as
    ``(x_left, x_right, y_bottom, y_top)`` via ``b[0]..b[3]``.
    *cct* is not used.

    Returns ``(patch_collection, scalar_mappable)``; the mappable carries the
    colour limits and data array, e.g. for a colorbar.
    """
    mapper = ScalarMappable(cmap_norm, cmap)
    mapper.set_clim(cmap_vmin, cmap_vmax)
    mapper.set_array(val)
    face_colors = mapper.to_rgba(val)

    rects = [Rectangle((b[0], b[2]), b[1] - b[0], b[3] - b[2]) for b in clm]
    collection = PatchCollection(rects)
    collection.set_edgecolor('face')
    collection.set_facecolor(face_colors)
    return collection, mapper
def plotThomsonEdgeData(self, ax):
    """
    Plot every edge Thomson-scattering density channel (NE) against time on
    *ax*. Each channel's line colour is the mean of its per-time-slice
    colours from the (negated) RMID radius.
    """
    # Raw strings: '\E' is not a valid escape sequence (a SyntaxWarning on
    # Python 3.12+); the runtime node paths are unchanged.
    proNode = self.elecTree.getNode(r'\ELECTRONS::TOP.YAG_EDGETS.RESULTS:NE')
    rhoNode = self.elecTree.getNode(r'\ELECTRONS::TOP.YAG_EDGETS.RESULTS:RMID')
    rpro = proNode.data()
    rrho = rhoNode.data()
    rtime = proNode.dim_of().data()

    # Keep only time slices where the first channel's radius is positive.
    goodTimes = rrho[0] > 0
    pro = rpro[:, goodTimes]
    rho = rrho[:, goodTimes]
    time = rtime[goodTimes]

    # A default ScalarMappable autoscales its norm to the data on first use.
    sm = ScalarMappable()
    rhoColor = sm.to_rgba(-rho)

    for i in range(rpro.shape[0]):
        ax.plot(time, pro[i], c=np.mean(rhoColor[i], axis=0))
    ax.set_ylabel('ne')
def multiTimeTrace(self, i, time, rhoFlat, data, ylabel, reverse=True):
    """
    Plot one coloured trace per entry of *rhoFlat* on ``self.axes[i]``.

    Traces are taken from the rows of *data* when it has fewer rows than
    columns, otherwise from its columns; a 2-D *time* is sliced the same
    way, a 1-D *time* is shared by all traces. Colours come from
    'gist_rainbow' ('gist_rainbow_r' when *reverse* is False) applied to
    *rhoFlat*.
    """
    cmap_name = 'gist_rainbow' if reverse else 'gist_rainbow_r'
    mapper = ScalarMappable(cmap=cmap_name)
    trace_colors = mapper.to_rgba(rhoFlat)

    for j in range(len(rhoFlat)):
        channels_in_rows = data.shape[0] < data.shape[1]
        time_is_2d = len(time.shape) > 1
        if channels_in_rows:
            trace = data[j, :]
            t_axis = time[j, :] if time_is_2d else time
        else:
            trace = data[:, j]
            t_axis = time[:, j] if time_is_2d else time
        self.axes[i].plot(t_axis, trace, c=trace_colors[j],
                          linestyle='-', marker='.')
    self.axes[i].set_ylabel(ylabel)
def plot2D(X, filename=None, last_column_color=False):
    """
    Plot the first two columns of *X* as circle markers, one per row.

    With *last_column_color*, markers are coloured by the last column via an
    autoscaled 'jet' colormap; otherwise every marker is blue.
    When *filename* is None the figure is shown; otherwise it is saved to
    ``filename + ".png"``, cleared and closed.
    """
    xs = X[:, 0]
    ys = X[:, 1]
    n_points = X.shape[0]

    if not last_column_color:
        point_colors = 'b' * n_points  # one 'b' character per point
    else:
        c = X[:, -1]
        jet = get_cmap('jet')
        auto_norm = Normalize()
        auto_norm.autoscale(c)
        point_colors = ScalarMappable(norm=auto_norm, cmap=jet).to_rgba(c)

    fig = figure()
    ax = fig.add_subplot(111)
    for x_pt, y_pt, pt_color in zip(xs, ys, point_colors):
        ax.plot(x_pt, y_pt, 'o', color=pt_color)

    if filename is not None:
        fig.savefig(filename + ".png")
        fig.clf()
        close()
    else:
        fig.show()
def plot_event(catalog, projection='cyl', resolution='l',
               continent_fill_color='0.9', water_fill_color='white',
               label=None, color='depth', pretty=False, colormap=None,
               llat=-90, ulat=90, llon=-180, ulon=180, figsize=(16, 24),
               par_range=(-90., 120., 30.), mer_range=(0., 360., 60.),
               showHour=False, M_above=0.0, location='World',
               **kwargs):  # @UnusedVariable
    """
    Creates preview map of all events in current Catalog object.

    :type projection: str, optional
    :param projection: The map projection. Currently supported are
        * ``"cyl"`` (Will plot the whole world.)
        * ``"ortho"`` (Will center around the mean lat/long.)
        * ``"local"`` (Will plot around local events)
        Defaults to "cyl"
    :type resolution: str, optional
    :param resolution: Resolution of the boundary database to use. Will be
        passed directly to the basemap module. Possible values are
        * ``"c"`` (crude)
        * ``"l"`` (low)
        * ``"i"`` (intermediate)
        * ``"h"`` (high)
        * ``"f"`` (full)
        Defaults to ``"l"``
    :type continent_fill_color: Valid matplotlib color, optional
    :param continent_fill_color: Color of the continents. Defaults to
        ``"0.9"`` which is a light gray.
    :type water_fill_color: Valid matplotlib color, optional
    :param water_fill_color: Color of all water bodies.
        Defaults to ``"white"``.
    :type label: str, optional
    :param label: Events will be labeled based on the chosen property.
        Possible values are
        * ``"magnitude"``
        * ``"depth"``
        * ``None``
        Defaults to ``None``
    :type color: str, optional
    :param color: The events will be color-coded based on the chosen
        property. Possible values are
        * ``"date"``
        * ``"depth"``
        Defaults to ``"depth"``
    :type colormap: str, optional, any matplotlib colormap
    :param colormap: The colormap for color-coding the events.
        The event with the smallest property will have the color of one end
        of the colormap and the event with the biggest property the color of
        the other end with all other events in between.
        Defaults to None which will use the default colormap for the date
        encoding and a colormap going from green over yellow to red for the
        depth encoding.
    :param llat, ulat, llon, ulon: Map bounds (degrees) used by the
        ``"cyl"`` projection and forwarded to ``get_event_info``.
    :param figsize: Figure size in inches.
    :param par_range, mer_range: ``(start, stop, step)`` of the labelled
        parallels and meridians.
    :param pretty: Draw an ETOPO relief background instead of flat fills.
    :param showHour: Include the time of day in date colorbar tick labels.
    :param M_above: Forwarded to ``get_event_info`` (presumably a minimum
        magnitude -- confirm against that helper).
    :param location: ``"US"`` or ``"CA"`` override the bounds, grid ranges
        and map centre with regional presets; any other value keeps them.

    .. rubric:: Example

    >>> cat = readEvents( \
        "http://www.seismicportal.eu/services/event/search?magMin=8.0") \
        # doctest:+SKIP
    >>> cat.plot()  # doctest:+SKIP
    """
    from mpl_toolkits.basemap import Basemap
    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize
    from matplotlib.cm import ScalarMappable
    import matplotlib as mpl

    if color not in ('date', 'depth'):
        raise ValueError('Events can be color coded by date or depth. '
                         "'%s' is not supported." % (color,))
    if label not in (None, 'magnitude', 'depth'):
        raise ValueError('Events can be labeled by magnitude or events can'
                         ' not be labeled. '
                         "'%s' is not supported." % (label,))

    if location == 'US':
        llon = -125
        llat = 20
        ulon = -60
        ulat = 60
        lat_0 = 38
        lon_0 = -122.0
        par_range = (20, 61, 10)
        mer_range = (-120, -59, 20)
    elif location == 'CA':
        # Numeric bounds, like the defaults and the 'US' preset (these were
        # previously strings).
        llat = 30
        ulat = 45
        llon = -130
        ulon = -110
        lat_0 = 38
        lon_0 = -122.0
        par_range = (30, 46, 5)
        mer_range = (-130, -109, 10)
    else:
        lat_0 = 0
        lon_0 = 0

    lats, lons, mags, times, labels, colors = get_event_info(
        catalog, M_above, llat, ulat, llon, ulon, color, label)

    min_color = min(colors)
    max_color = max(colors)

    # Create the colormap for date based plotting.
    if colormap is None:
        if color == "date":
            colormap = plt.get_cmap()
        else:
            # Choose green->yellow->red for the depth encoding.
            colormap = plt.get_cmap("RdYlGn_r")

    scal_map = ScalarMappable(norm=Normalize(min_color, max_color),
                              cmap=colormap)
    scal_map.set_array(np.linspace(0, 1, 1))

    # The colorbar (and its axes) is only used when more than one event
    # survives the filtering in get_event_info. Decide on the filtered
    # events so cm_ax always exists whenever the colorbar is drawn.
    many_events = len(mags) > 1

    fig = plt.figure(figsize=figsize)
    if many_events:
        map_ax = fig.add_axes([0.03, 0.13, 0.94, 0.82])
        # rect = [left, bottom, width, height]
        cm_ax = fig.add_axes([0.98, 0.39, 0.04, 0.3])
        plt.sca(map_ax)
    else:
        map_ax = fig.add_axes([0.05, 0.05, 0.90, 0.90])

    if projection == 'cyl':
        bmap = Basemap(resolution=resolution, lat_0=lat_0, lon_0=lon_0,
                       llcrnrlon=llon, llcrnrlat=llat,
                       urcrnrlon=ulon, urcrnrlat=ulat)
    elif projection == 'ortho':
        bmap = Basemap(projection='ortho', resolution=resolution,
                       area_thresh=1000.0, lat_0=sum(lats) / len(lats),
                       lon_0=sum(lons) / len(lons))
    elif projection == 'local':
        # Events straddling the dateline: work in 0..360 longitudes.
        if min(lons) < -150 and max(lons) > 150:
            max_lons = max(np.array(lons) % 360)
            min_lons = min(np.array(lons) % 360)
        else:
            max_lons = max(lons)
            min_lons = min(lons)
        lat_0 = (max(lats) + min(lats)) / 2.
        lon_0 = (max_lons + min_lons) / 2.
        if lon_0 > 180:
            lon_0 -= 360
        deg2m_lat = 2 * np.pi * 6371 * 1000 / 360
        deg2m_lon = deg2m_lat * np.cos(lat_0 / 180 * np.pi)
        if len(lats) > 1:
            height = (max(lats) - min(lats)) * deg2m_lat
            width = (max_lons - min_lons) * deg2m_lon
            margin = 0.2 * (width + height)
            height += margin
            width += margin
        else:
            height = 2.0 * deg2m_lat
            width = 5.0 * deg2m_lon
        bmap = Basemap(projection='aeqd', resolution=resolution,
                       area_thresh=1000.0, lat_0=lat_0, lon_0=lon_0,
                       width=width, height=height)

        # not most elegant way to calculate some round lats/lons
        def linspace2(val1, val2, N):
            """
            returns around N 'nice' values between val1 and val2
            """
            dval = val2 - val1
            round_pos = int(round(-np.log10(1. * dval / N)))
            delta = round(2. * dval / N, round_pos) / 2
            new_val1 = np.ceil(val1 / delta) * delta
            new_val2 = np.floor(val2 / delta) * delta
            N = (new_val2 - new_val1) / delta + 1
            # np.linspace needs an integer count (floats raise TypeError on
            # NumPy >= 1.18).
            return np.linspace(new_val1, new_val2, int(round(N)))

        N1 = int(np.ceil(height / max(width, height) * 8))
        N2 = int(np.ceil(width / max(width, height) * 8))
        bmap.drawparallels(linspace2(lat_0 - height / 2 / deg2m_lat,
                                     lat_0 + height / 2 / deg2m_lat, N1),
                           labels=[0, 1, 1, 0])
        if min(lons) < -150 and max(lons) > 150:
            lon_0 %= 360
        meridians = linspace2(lon_0 - width / 2 / deg2m_lon,
                              lon_0 + width / 2 / deg2m_lon, N2)
        meridians[meridians > 180] -= 360
        bmap.drawmeridians(meridians, labels=[1, 0, 0, 1])
    else:
        msg = "Projection %s not supported." % projection
        raise ValueError(msg)

    # draw coast lines, country boundaries, fill continents.
    bmap.drawcoastlines(color="0.4")
    bmap.drawcountries(color="0.75")
    if location == 'CA' or location == 'US':
        bmap.drawstates(color="0.75")

    # draw lat/lon grid lines
    bmap.drawparallels(np.arange(par_range[0], par_range[1], par_range[2]),
                       labels=[1, 0, 0, 0], linewidth=0)
    bmap.drawmeridians(np.arange(mer_range[0], mer_range[1], mer_range[2]),
                       labels=[0, 0, 0, 1], linewidth=0)

    if pretty:
        bmap.etopo()
    else:
        bmap.drawmapboundary(fill_color=water_fill_color)
        bmap.fillcontinents(color=continent_fill_color,
                            lake_color=water_fill_color)

    # compute the native map projection coordinates for events.
    x, y = bmap(lons, lats)
    # plot labels
    if 100 > len(mags) > 1:
        for name, xpt, ypt, colorpt in zip(labels, x, y, colors):
            # Basemap sets unprojectable points to very large coordinates.
            if xpt > 1e25:
                continue
            plt.text(xpt, ypt, name, weight="heavy",
                     color=scal_map.to_rgba(colorpt))
    elif len(mags) == 1:
        plt.text(x[0], y[0], labels[0], weight="heavy", color="red")

    min_size = 6
    max_size = 30
    min_mag = min(mags)
    max_mag = max(mags)
    if many_events:
        mag_range = float(max_mag - min_mag)
        if mag_range > 0:
            frac = [(_i - min_mag) / mag_range for _i in mags]
        else:
            # All magnitudes equal: avoid dividing by zero, use a mid size.
            frac = [0.5] * len(mags)
        # Map [0, 1] onto [min_size, max_size] so the smallest event is
        # still visible (it previously got a marker area of 0).
        magnitude_size = [(min_size + _i * (max_size - min_size)) ** 2
                          for _i in frac]
        colors_plot = [scal_map.to_rgba(c) for c in colors]
    else:
        magnitude_size = 15.0 ** 2
        colors_plot = "red"
    bmap.scatter(x, y, marker='o', s=magnitude_size, c=colors_plot,
                 zorder=10)

    if many_events:
        plt.title(
            "{event_count} events ({start} to {end}) "
            "- Color codes {colorcode}, size the magnitude".format(
                event_count=len(lats),
                start=min(times).strftime("%Y-%m-%d"),
                end=max(times).strftime("%Y-%m-%d"),
                colorcode="origin time" if color == "date" else "depth"))
    else:
        plt.title("Event at %s" % times[0].strftime("%Y-%m-%d"))

    # Only show the colorbar for more than one event.
    if many_events:
        cb = mpl.colorbar.ColorbarBase(ax=cm_ax, cmap=colormap,
                                       orientation='vertical')
        cb.set_ticks([0, 0.25, 0.5, 0.75, 1.0])
        color_range = max_color - min_color
        tick_values = [min_color,
                       min_color + color_range * 0.25,
                       min_color + color_range * 0.50,
                       min_color + color_range * 0.75,
                       max_color]
        date_fmt = '%Y-%b-%d, %H:%M:%S %p' if showHour else '%Y-%b-%d'
        cb.set_ticklabels([_i.strftime(date_fmt) if color == "date"
                           else '%.1fkm' % (_i) for _i in tick_values])

    plt.show()
def plot_mt(earthquakes, mt, event_id, location=None, M_above=5.0,
            show_above_M=True, llat='-90', ulat='90', llon='-170',
            ulon='190', figsize=(12, 8), radius=25, dist_bt=600,
            mt_width=2, angle_step=20, show_eq=True,
            par_range=(-90., 120., 30.), mer_range=(0, 360, 60),
            pretty=False, legend_loc=4, title='', resolution='l'):
    '''
    Function to plot moment tensors on the map

    Input:
        earthquakes - list of earthquake information; element 6 of each
                      entry is used as the event time
        mt - list of focal/moment_tensor information, rows of
             (lat, lon, depth, mag, *focal_mechanism)
        event_id - event IDs corresponding to the earthquakes (accepted for
                   compatibility; not used for plotting)
        location - predefined region: 'US', 'CAL' (California) or 'World'.
                   'World' adds the event count to the title and a
                   magnitude legend. Any other value (default None) plots the
                   bounds given by llat/ulat/llon/ulon
        M_above - Only show the events with magnitude larger than or equal to
                  this number, default is 5.0, use with show_above_M
        show_above_M - Flag to turn on the M_above option, default is True
        llat - bottom left corner latitude, default is -90
        ulat - upper right corner latitude, default is 90
        llon - bottom left corner longitude, default is -170
        ulon - upper right corner longitude, default is 190
        figsize - figure size, default is (12,8)
        radius - used in checking collisions (MT), put the MT on a circle
                 with this radius, default is 25
        dist_bt - used in checking collisions (MT), if two events within
                  dist_bt km, then we say it is a collision, default is 600
        angle_step - used in checking collisions (MT), this is to decide the
                     angle step on the circle, default is 20 degree
        mt_width - size of the MT on the map (multiplied by the magnitude).
                   Different scale of the map may need different size
        show_eq - flag to show the seismicity (coloured by time) as well,
                  default is True
        par_range - range of latitudes you want to label on the map, start
                    lat, end lat and step size, default is (-90., 120., 30.)
        mer_range - range of longitudes you want to label on the map, start
                    lon, end lon and step size, default is (0, 360, 60)
        pretty - draw a pretty map, default is False to make faster plot
        legend_loc - location of the legend, default is 4
        title - title of the plot (replaced for the 'US', 'CAL' and 'World'
                presets)
        resolution - resolution of the map, Possible values are
            * ``"c"`` (crude)
            * ``"l"`` (low)
            * ``"i"`` (intermediate)
            * ``"h"`` (high)
            * ``"f"`` (full)
            Defaults to ``"l"``
    '''
    if location == 'US':
        llon = -125
        llat = 20
        ulon = -70
        ulat = 60
        M_above = 4.0
        radius = 5
        dist_bt = 200
        par_range = (20, 60, 15)
        mer_range = (-120, -60, 15)
        mt_width = 0.8
    elif location == 'CAL':
        llat = '30'
        ulat = '45'
        llon = '-130'
        ulon = '-110'
        M_above = 3.0
        radius = 1.5
        dist_bt = 50
        mt_width = 0.3
    elif location != 'World':
        # Keep 'World' intact: the title and magnitude-legend branches below
        # depend on it and were unreachable when it was reset to None.
        location = None

    times = [event[6] for event in earthquakes]
    if show_above_M:
        mags = [row[3] for row in mt]
        index = np.array(mags) >= M_above
        mt_select = np.array(mt)[index]
        times_select = np.array(times)[index]
    else:
        times_select = times
        mt_select = mt

    lats = [row[0] for row in mt_select]
    lons = [row[1] for row in mt_select]
    depths = [row[2] for row in mt_select]
    mags = [row[3] for row in mt_select]
    focmecs = [row[4:] for row in mt_select]

    lats_m, lons_m, indicator = check_collision(lats, lons, radius, dist_bt,
                                                angle_step)

    count = 0
    # Seismicity is coloured by event time.
    min_color = min(times_select)
    max_color = max(times_select)
    colormap = plt.get_cmap()
    scal_map = ScalarMappable(norm=cc.Normalize(min_color, max_color),
                              cmap=colormap)
    scal_map.set_array(np.linspace(0, 1, 1))
    colors_plot = [scal_map.to_rgba(c) for c in times_select]

    fig, ax1 = plt.subplots(1, 1, figsize=figsize)

    if show_eq:
        cm_ax = fig.add_axes([0.98, 0.39, 0.04, 0.3])
        plt.sca(ax1)
        cb = mpl.colorbar.ColorbarBase(ax=cm_ax, cmap=colormap,
                                       orientation='vertical')
        cb.set_ticks([0, 0.25, 0.5, 0.75, 1.0])
        color_range = max_color - min_color
        cb.set_ticklabels([_i.strftime('%Y-%b-%d, %H:%M:%S %p')
                           for _i in [min_color,
                                      min_color + color_range * 0.25,
                                      min_color + color_range * 0.50,
                                      min_color + color_range * 0.75,
                                      max_color]])

    m = Basemap(projection='cyl', lon_0=142.36929, lat_0=38.3215,
                llcrnrlon=llon, llcrnrlat=llat, urcrnrlon=ulon,
                urcrnrlat=ulat, resolution=resolution)
    m.drawcoastlines()
    m.drawmapboundary()
    m.drawcountries()
    m.drawparallels(np.arange(par_range[0], par_range[1], par_range[2]),
                    labels=[1, 0, 0, 0], linewidth=0)
    m.drawmeridians(np.arange(mer_range[0], mer_range[1], mer_range[2]),
                    labels=[0, 0, 0, 1], linewidth=0)
    if pretty:
        m.etopo()
    else:
        m.fillcontinents()

    # Beach balls go at the collision-adjusted positions.
    x, y = m(lons_m, lats_m)
    for i in range(len(focmecs)):
        # Beach() errors when mrr is exactly zero, so nudge it.
        if focmecs[i][0] == 0:
            focmecs[i][0] = 0.001
        width = mags[i] * mt_width

        depth = depths[i]
        if depth <= 50:
            color = '#FFA500'
        elif depth <= 100:
            color = '#FFFF00'
        elif depth <= 150:
            color = '#00FF00'
        elif depth <= 200:
            color = 'b'
        else:
            color = 'r'

        if indicator[i] == 1:
            # Connect a displaced beach ball to the true epicentre.
            m.plot([lons[i], lons_m[i]], [lats[i], lats_m[i]], 'k')

        try:
            b = Beach(focmecs[i], xy=(x[i], y[i]), width=width,
                      linewidth=1, facecolor=color, alpha=1)
        except Exception:
            # Skip mechanisms Beach cannot draw; previously the stale (or
            # undefined) ball from an earlier iteration was added instead.
            continue
        count += 1
        # Invisible, pickable marker over the beach ball.
        ax1.plot(x[i], y[i], 'o', picker=5, markersize=30, alpha=0)
        b.set_zorder(3)
        ax1.add_collection(b)

    d = 5
    circ1 = Line2D([0], [0], linestyle="none", marker="o", alpha=0.6,
                   markersize=10, markerfacecolor="#FFA500")
    circ2 = Line2D([0], [0], linestyle="none", marker="o", alpha=0.6,
                   markersize=10, markerfacecolor="#FFFF00")
    circ3 = Line2D([0], [0], linestyle="none", marker="o", alpha=0.6,
                   markersize=10, markerfacecolor="#00FF00")
    circ4 = Line2D([0], [0], linestyle="none", marker="o", alpha=0.6,
                   markersize=10, markerfacecolor="b")
    circ5 = Line2D([0], [0], linestyle="none", marker="o", alpha=0.6,
                   markersize=10, markerfacecolor="r")
    M4 = Line2D([0], [0], linestyle="none", marker="o", alpha=0.4,
                markersize=4 * d, markerfacecolor="k")
    M5 = Line2D([0], [0], linestyle="none", marker="o", alpha=0.4,
                markersize=5 * d, markerfacecolor="k")
    M6 = Line2D([0], [0], linestyle="none", marker="o", alpha=0.4,
                markersize=6 * d, markerfacecolor="k")
    M7 = Line2D([0], [0], linestyle="none", marker="o", alpha=0.4,
                markersize=7 * d, markerfacecolor="k")

    if location == 'World':
        title = str(count) + ' events with focal mechanism - color codes depth, size the magnitude'
    elif location == 'US':
        title = 'US events with focal mechanism - color codes depth, size the magnitude'
    elif location == 'CAL':
        title = 'California events with focal mechanism - color codes depth, size the magnitude'

    # Raw strings: '\l' is not a valid escape; the text is unchanged.
    legend1 = plt.legend((circ1, circ2, circ3, circ4, circ5),
                         (r"depth $\leq$ 50 km",
                          r"50 km $<$ depth $\leq$ 100 km",
                          r"100 km $<$ depth $\leq$ 150 km",
                          r"150 km $<$ depth $\leq$ 200 km",
                          r"200 km $<$ depth"),
                         numpoints=1, loc=legend_loc)
    plt.title(title)
    plt.gca().add_artist(legend1)

    if location == 'World':
        plt.legend((M4, M5, M6, M7), ("M 4.0", "M 5.0", "M 6.0", "M 7.0"),
                   numpoints=1, loc=legend_loc)

    # Seismicity is drawn at the true epicentres.
    x, y = m(lons, lats)
    min_size = 6
    if show_eq:
        if len(lats) > 1:
            magnitude_size = [(_i * min_size / 2) ** 2 for _i in mags]
        else:
            magnitude_size = 15.0 ** 2
            colors_plot = "red"
        m.scatter(x, y, marker='o', s=magnitude_size, c=colors_plot,
                  zorder=10)
    plt.show()

    print('Max magnitude %s Min magnitude %s' % (np.max(mags), np.min(mags)))
def _get_2d_plot(self, label_stable=True, label_unstable=True,
                 ordering=None, energy_colormap=None, vmin_mev=-60.0,
                 vmax_mev=60.0, show_colorbar=True,
                 process_attributes=False):
    """
    Build the 2D phase-diagram plot and return the pyplot module.

    Usually I won't do imports in methods, but since plotting is a fairly
    expensive library to load and not all machines have matplotlib
    installed, I have done it this way.

    Args:
        label_stable: annotate stable entries with their names.
        label_unstable: annotate unstable entries (drawn only when
            ``self.show_unstable`` is true).
        ordering: optional ordering passed to ``order_phase_diagram``.
        energy_colormap: None for fixed marker colours; ``'default'`` for a
            built-in green/red map; otherwise used as the ScalarMappable
            cmap to colour entries by energy.
        vmin_mev, vmax_mev: colour-scale limits in meV/atom.
        show_colorbar: add a colorbar when ``energy_colormap`` is set.
        process_attributes: draw entries whose ``attribute`` is neither
            None nor "existing" as stars (and label "new" ones in green).

    Returns:
        The pyplot module obtained from ``get_publication_quality_plot``.
    """
    plt = get_publication_quality_plot(8, 6)
    from matplotlib.font_manager import FontProperties
    if ordering is None:
        (lines, labels, unstable) = self.pd_plot_data
    else:
        (_lines, _labels, _unstable) = self.pd_plot_data
        (lines, labels, unstable) = order_phase_diagram(
            _lines, _labels, _unstable, ordering)

    if energy_colormap is None:
        if process_attributes:
            for x, y in lines:
                plt.plot(x, y, "k-", linewidth=3, markeredgecolor="k")
            # The attribute only distinguishes new compounds from existing
            # ones (from the ICSD or from the MP).
            for x, y in labels.keys():
                if labels[(x, y)].attribute is None or \
                        labels[(x, y)].attribute == "existing":
                    plt.plot(x, y, "ko", linewidth=3, markeredgecolor="k",
                             markerfacecolor="b", markersize=12)
                else:
                    plt.plot(x, y, "k*", linewidth=3, markeredgecolor="k",
                             markerfacecolor="g", markersize=18)
        else:
            for x, y in lines:
                plt.plot(x, y, "ko-", linewidth=3, markeredgecolor="k",
                         markerfacecolor="b", markersize=15)
    else:
        from matplotlib.colors import Normalize, LinearSegmentedColormap
        from matplotlib.cm import ScalarMappable
        pda = PDAnalyzer(self._pd)
        for x, y in lines:
            plt.plot(x, y, "k-", linewidth=3, markeredgecolor="k")
        vmin = vmin_mev / 1000.0
        vmax = vmax_mev / 1000.0
        if energy_colormap == 'default':
            # Green below zero, red above, with a hard break at zero.
            mid = - vmin / (vmax - vmin)
            cmap = LinearSegmentedColormap.from_list(
                'my_colormap', [(0.0, '#005500'), (mid, '#55FF55'),
                                (mid, '#FFAAAA'), (1.0, '#FF0000')])
        else:
            cmap = energy_colormap
        norm = Normalize(vmin=vmin, vmax=vmax)
        _map = ScalarMappable(norm=norm, cmap=cmap)
        _energies = [pda.get_equilibrium_reaction_energy(entry)
                     for coord, entry in labels.items()]
        # Stable entries are forced onto the green (negative) side.
        energies = [en if en < 0.0 else -0.00000001 for en in _energies]
        vals_stable = _map.to_rgba(energies)
        if process_attributes:
            for ii, (x, y) in enumerate(labels.keys()):
                if labels[(x, y)].attribute is None or \
                        labels[(x, y)].attribute == "existing":
                    plt.plot(x, y, "o", markerfacecolor=vals_stable[ii],
                             markersize=12)
                else:
                    plt.plot(x, y, "*", markerfacecolor=vals_stable[ii],
                             markersize=18)
        else:
            for ii, (x, y) in enumerate(labels.keys()):
                plt.plot(x, y, "o", markerfacecolor=vals_stable[ii],
                         markersize=15)

    font = FontProperties()
    font.set_weight("bold")
    font.set_size(24)

    # Sets a nice layout depending on the type of PD. Also defines a
    # "center" for the PD, which then allows the annotations to be spread
    # out in a nice manner.
    if len(self._pd.elements) == 3:
        plt.axis("equal")
        plt.xlim((-0.1, 1.2))
        plt.ylim((-0.1, 1.0))
        plt.axis("off")
        center = (0.5, math.sqrt(3) / 6)
    else:
        all_coords = labels.keys()
        miny = min([c[1] for c in all_coords])
        ybuffer = max(abs(miny) * 0.1, 0.1)
        plt.xlim((-0.1, 1.1))
        plt.ylim((miny - ybuffer, ybuffer))
        center = (0.5, miny / 2)
        plt.xlabel("Fraction", fontsize=28, fontweight='bold')
        plt.ylabel("Formation energy (eV/fu)", fontsize=28,
                   fontweight='bold')

    def _offset_and_alignment(coords):
        """Return (offset vector, halign, valign) for a label at *coords*.

        The 10-point offset points away from the PD centre.
        """
        vec = np.array(coords) - center
        length = np.linalg.norm(vec)
        vec = vec / length * 10 if length != 0 else vec
        valign = "bottom" if vec[1] > 0 else "top"
        if vec[0] < -0.01:
            halign = "right"
        elif vec[0] > 0.01:
            halign = "left"
        else:
            halign = "center"
        return vec, halign, valign

    for coords in sorted(labels.keys(), key=lambda x: -x[1]):
        entry = labels[coords]
        label = entry.name
        vec, halign, valign = _offset_and_alignment(coords)
        if label_stable:
            if process_attributes and entry.attribute == 'new':
                plt.annotate(latexify(label), coords, xytext=vec,
                             textcoords="offset points",
                             horizontalalignment=halign,
                             verticalalignment=valign,
                             fontproperties=font,
                             color='g')
            else:
                plt.annotate(latexify(label), coords, xytext=vec,
                             textcoords="offset points",
                             horizontalalignment=halign,
                             verticalalignment=valign,
                             fontproperties=font)

    if self.show_unstable:
        font = FontProperties()
        font.set_size(16)
        pda = PDAnalyzer(self._pd)
        energies_unstable = [pda.get_e_above_hull(entry)
                             for entry, coord in unstable.items()]
        if energy_colormap is not None:
            energies.extend(energies_unstable)
            vals_unstable = _map.to_rgba(energies_unstable)
        for ii, (entry, coords) in enumerate(unstable.items()):
            # Compute alignment per entry; previously the values left over
            # from the last stable label were reused (NameError if none).
            vec, halign, valign = _offset_and_alignment(coords)
            label = entry.name
            if energy_colormap is None:
                plt.plot(coords[0], coords[1], "ks", linewidth=3,
                         markeredgecolor="k", markerfacecolor="r",
                         markersize=8)
            else:
                plt.plot(coords[0], coords[1], "s", linewidth=3,
                         markeredgecolor="k",
                         markerfacecolor=vals_unstable[ii],
                         markersize=8)
            if label_unstable:
                plt.annotate(latexify(label), coords, xytext=vec,
                             textcoords="offset points",
                             horizontalalignment=halign, color="b",
                             verticalalignment=valign,
                             fontproperties=font)

    if energy_colormap is not None and show_colorbar:
        _map.set_array(energies)
        cbar = plt.colorbar(_map)
        cbar.set_label(
            'Energy [meV/at] above hull (in red)\nInverse energy ['
            'meV/at] above hull (in green)',
            rotation=-90, ha='left', va='center')
        # Relabel ticks from eV to meV.
        ticks = cbar.ax.get_yticklabels()
        cbar.ax.set_yticklabels(['${v}$'.format(
            v=float(t.get_text().strip('$')) * 1000.0) for t in ticks])
    f = plt.gcf()
    f.set_size_inches((8, 6))
    plt.subplots_adjust(left=0.09, right=0.98, top=0.98, bottom=0.07)
    return plt
def hm_to_rgb(R, X=None, scaling=3, shape=(), sigma=2, cmap='jet', normalize=True):
    '''
    Takes as input an intensity array and produces a rgb image for the
    represented heatmap. optionally draws the outline of another input on
    top of it.

    Parameters
    ----------
    R : numpy.ndarray
        the heatmap to be visualized, shaped [M x N]
    X : numpy.ndarray
        optional. some input, usually the data point for which the heatmap R
        is for, which shall serve as a template for a black outline to be
        drawn on top of the image, shaped [M x N]
    scaling : int
        factor, on how to enlarge the heatmap (to control resolution and as
        a inverse way to control outline thickness) after reshaping it using
        shape.
    shape : tuple or list, length = 2
        optional. if not given, X is reshaped to be square.
    sigma : double
        optional. sigma-parameter for the canny algorithm used for edge
        detection. the found edges are drawn as outlines.
    cmap : str
        optional. color map of choice
    normalize : bool
        optional. whether to normalize the heatmap to [-1 1] prior to
        colorization or not. An all-zero heatmap is left unscaled.

    Returns
    -------
    rgbimg : numpy.ndarray
        three-dimensional array of shape [scaling*H x scaling*W x 3],
        where H*W == M*N
    '''
    sm = ScalarMappable(cmap=cmap)  # prepare heatmap -> rgb image conversion
    if normalize:
        peak = np.max(np.abs(R))
        # An all-zero heatmap would otherwise become all NaN; copy instead so
        # the caller's array is still never modified below.
        R = np.array(R, dtype=float) if peak == 0 else R / peak
        # Anchors for controlled color mapping, to be drawn over later by
        # repaint_corner_pixels.
        R[0, 0] = 1
        R[-1, -1] = -1

    R = enlarge_image(vec2im(R, shape), scaling)
    rgbimg = sm.to_rgba(R)[:, :, 0:3]
    rgbimg = repaint_corner_pixels(rgbimg, scaling)

    if X is not None:  # compute the outline of the input
        X = enlarge_image(vec2im(X, shape), scaling)
        xdims = X.shape
        Rdims = R.shape

        if xdims != Rdims:
            print('transformed heatmap and data dimension mismatch. data dimensions differ?')
            print('R.shape = {0} X.shape = {1}'.format(Rdims, xdims))
            print('skipping drawing of outline\n')
        else:
            edges = feature.canny(X, sigma=sigma)
            edges = np.invert(np.dstack([edges] * 3)) * 1.0
            rgbimg *= edges  # set outline pixels to black color

    return rgbimg
class SpiroGraph(object):
    '''
    Spirograph drawer with matplotlib slider widgets to change parameters.
    Parameters of line are:

    R: The radius of the big circle
    r: The radius of the small circle which rolls along the inside of the bigger circle
    p: distance from centre of smaller circle to point in the circle where the pen hole is.
    tmax: the angle through which the smaller circle is rotated to draw the spirograph
    tstep: how often matplotlib plots a point
    a, b, c: parameters of the linewidth equation.

    Constructing an instance builds the figure, wires up the widgets and
    calls plt.show().
    '''

    # kwargs for each of the matplotlib sliders (order matters: run() and the
    # update methods index self.sliders positionally)
    slider_kwargs = (
        {'label': 't_max', 'valmin': np.pi, 'valmax': 200 * np.pi,
         'valinit': tmax0, 'valfmt': PiString()},
        {'label': 't_step', 'valmin': 0.01, 'valmax': 10, 'valinit': tstep0},
        {'label': 'R', 'valmin': 1, 'valmax': 200, 'valinit': R0},
        {'label': 'r', 'valmin': 1, 'valmax': 200, 'valinit': r0},
        {'label': 'p', 'valmin': 1, 'valmax': 200, 'valinit': p0},
        {'label': 'colour', 'valmin': 0, 'valmax': 1, 'valinit': 1},
        {'label': 'width_a', 'valmin': 0.5, 'valmax': 10, 'valinit': 1},
        {'label': 'width_b', 'valmin': 0, 'valmax': 10, 'valinit': 0},
        {'label': 'width_c', 'valmin': 0, 'valmax': 10, 'valinit': 0.5})

    # kwargs for the background-colour and solid/variable-colour radio buttons
    rbutton_kwargs = (
        {'labels': ('black', 'white'), 'activecolor': 'white', 'active': 0},
        {'labels': ('solid', 'variable'), 'activecolor': 'white', 'active': 0})

    def __init__(self, colormap, figsize=(7, 10)):
        '''
        colormap: colormap name (or object) passed to ScalarMappable.
        figsize: figure size; note this is written into the global
            plt.rcParams['figure.figsize'].
        '''
        self.colormap_name = colormap
        self.variable_color = False
        # Use ScalarMappable to map full colormap to range 0 - 1
        self.colormap = ScalarMappable(cmap=colormap)
        self.colormap.set_clim(0, 1)

        # set up main axis onto which to draw spirograph
        self.figsize = figsize
        plt.rcParams['figure.figsize'] = figsize
        self.fig, self.mainax = plt.subplots()
        plt.subplots_adjust(bottom=0.3)
        title = self.mainax.set_title('Spirograph Drawer!',
                                      size=20, color='white')
        self.text = [title, ]

        # set up slider axes
        self.slider_axes = [plt.axes([0.25, x, 0.65, 0.015])
                            for x in np.arange(0.05, 0.275, 0.025)]
        # same again for radio buttons
        self.rbutton_axes = [plt.axes([0.025, x, 0.1, 0.15])
                             for x in np.arange(0.02, 0.302, 0.15)]

        # use log scale for tstep slider
        self.slider_axes[1].set_xscale('log')

        # turn off frame, ticks and tick labels for all axes
        for ax in chain(self.slider_axes, self.rbutton_axes, [self.mainax, ]):
            ax.axis('off')

        # use axes and kwargs to create list of sliders/rbuttons
        self.sliders = [Slider(ax, **kwargs) for ax, kwargs
                        in zip(self.slider_axes, self.slider_kwargs)]
        self.rbuttons = [RadioButtons(ax, **kwargs) for ax, kwargs
                         in zip(self.rbutton_axes, self.rbutton_kwargs)]

        self.update_figcolors()

        # set up initial line
        self.t = np.arange(0, tmax0, tstep0)
        x, y = spiro_linefunc(self.t, R0, r0, p0)
        self.linecollection = LineCollection(
            segments(x, y),
            linewidths=spiro_linewidths(self.t, a0, b0, c0),
            color=self.colormap.to_rgba(col0))
        self.mainax.add_collection(self.linecollection)

        # creates the plot and connects sliders to various update functions
        self.run()

    def update_figcolors(self, bgcolor='black'):
        '''
        function run by background color radiobutton. Sets all labels, text,
        and sliders to foreground color, all axes to background color
        '''
        fgcolor = 'white' if bgcolor == 'black' else 'black'
        self.fig.set_facecolor(bgcolor)
        # Axes.set_axis_bgcolor was removed in matplotlib 2.2;
        # set_facecolor is its replacement.
        self.mainax.set_facecolor(bgcolor)
        for ax in chain(self.slider_axes, self.rbutton_axes):
            ax.set_facecolor(bgcolor)

        # set fgcolor elements to black or white, mostly elements of sliders
        for item in chain(map(attrgetter('label'), self.sliders),
                          map(attrgetter('valtext'), self.sliders),
                          map(attrgetter('poly'), self.sliders),
                          self.text,
                          *map(attrgetter('labels'), self.rbuttons)):
            item.set_color(fgcolor)

        self.update_radiobutton_colors()
        plt.draw()

    def _apply_linewidths(self):
        '''Set per-segment linewidths from self.t and the a, b, c sliders (no redraw).'''
        a, b, c = (s.val for s in self.sliders[6:])
        self.linecollection.set_linewidths(spiro_linewidths(self.t, a, b, c))

    def _apply_linecolors(self):
        '''Set line colours from the colour slider and colour mode (no redraw).'''
        # get current color value (a value between 1 and 0)
        col_val = self.sliders[5].val
        if not self.variable_color:
            # if solid color, convert color value to rgb and set the color
            self.linecollection.set_color(self.colormap.to_rgba(col_val))
            return

        # create values between 0 and 1 for each line segment; guard against
        # a zero span (a single t value) which would otherwise give NaN
        t_span = self.t.max()
        if t_span > 0:
            colors = self.t / t_span + col_val
        else:
            colors = np.full(self.t.shape, col_val, dtype=float)
        # use color value to roll colors
        colors[colors > 1] -= 1
        # to_rgba maps a 1-D array element-wise, returning an (N, 4) array
        self.linecollection.set_color(self.colormap.to_rgba(colors))

    def update_linewidths(self, *args):
        '''
        function run by a, b and c parameter sliders. Sets width of each line
        in linecollection according to sine function
        '''
        self._apply_linewidths()
        plt.draw()

    def update_linecolors(self, *args):
        '''
        function run by color slider and indirectly by variable/solid color
        radiobutton. Updates colors of each line in linecollection using the
        set colormap.
        '''
        self._apply_linecolors()
        plt.draw()

    def update_lineverts(self, *args):
        '''
        function run by R, r, p, tmax and tstep sliders to update line
        vertices. tmax/tstep change the number of segments, so the
        per-segment widths and colours are recomputed to match.
        '''
        tmax, tstep, R, r, p = (s.val for s in self.sliders[:5])
        self.t = np.arange(0, tmax, tstep)
        x, y = spiro_linefunc(self.t, R, r, p)
        self.linecollection.set_verts(segments(x, y))
        # the previous width/colour arrays were sized for the old self.t and
        # would otherwise be cycled over the new segments
        self._apply_linewidths()
        self._apply_linecolors()
        # change axis limits to pad new line nicely
        self.mainax.set(xlim=(min(x) - 5, max(x) + 5),
                        ylim=(min(y) - 5, max(y) + 5))
        plt.draw()

    def update_linecolor_setting(self, val):
        '''
        function run by solid/variable colour slider, alters variable_color
        attribute then calls update_linecolors
        '''
        if val == 'variable':
            self.variable_color = True
        elif val == 'solid':
            self.variable_color = False
        # need to update radiobutton colors here.
        self.update_radiobutton_colors()
        self.update_linecolors()

    def update_radiobutton_colors(self):
        '''
        makes radiobutton colors correct even on a changing axis background
        '''
        bgcolor = self.rbuttons[0].value_selected
        fgcolor = 'white' if bgcolor == 'black' else 'black'
        for i, b in enumerate(self.rbuttons):
            # find out index of the active button
            active_idx = self.rbutton_kwargs[i]['labels'].index(
                b.value_selected)
            # set button colors accordingly; `not active_idx` picks the other
            # button, which assumes exactly two buttons per group.
            # NOTE(review): RadioButtons.circles is deprecated in newer
            # matplotlib (3.7+) — confirm against the pinned version.
            b.circles[not active_idx].set_color(bgcolor)
            b.circles[active_idx].set_color(fgcolor)

    def run(self):
        '''
        set up slider functions
        '''
        verts_func = self.update_lineverts
        colors_func = self.update_linecolors
        widths_func = self.update_linewidths

        # create iterable of slider funcs to zip with sliders
        slider_update_funcs = chain(repeat(verts_func, 5),
                                    [colors_func, ],
                                    repeat(widths_func, 3))

        # set slider on_changed functions
        for s, f in zip(self.sliders, slider_update_funcs):
            s.on_changed(f)

        self.rbuttons[0].on_clicked(self.update_figcolors)
        self.rbuttons[1].on_clicked(self.update_linecolor_setting)

        plt.show()