def setup_axes(myg, num):
    """Create a grid of axes whose layout depends on the domain aspect ratio.

    Parameters
    ----------
    myg : grid object
        Anything exposing ``xmin``, ``xmax``, ``ymin`` and ``ymax``.
    num : int
        Number of panels required (must be >= 1).

    Returns
    -------
    f : Figure
        matplotlib figure number 1.
    axes : AxesGrid
        Grid holding at least ``num`` axes, each with its own colorbar axes.
    cbar_title : bool
        True only for the wide-domain layout (panels stacked in rows), where
        the colorbars are placed on top of the panels.
    """
    L_x = myg.xmax - myg.xmin
    L_y = myg.ymax - myg.ymin

    f = plt.figure(1)

    cbar_title = False

    # NOTE: ``add_all=True`` was dropped.  It only restated the historical
    # default, and newer matplotlib releases no longer accept the keyword.
    if L_x > 2 * L_y:
        # wide domain: we want num rows
        axes = AxesGrid(f, 111,
                        nrows_ncols=(num, 1),
                        share_all=True,
                        cbar_mode="each",
                        cbar_location="top",
                        cbar_pad="10%",
                        cbar_size="25%",
                        axes_pad=(0.25, 0.65),
                        label_mode="L")
        cbar_title = True

    elif L_y > 2 * L_x:
        # tall domain: we want num columns
        axes = AxesGrid(f, 111,
                        nrows_ncols=(1, num),
                        share_all=True,
                        cbar_mode="each",
                        cbar_location="right",
                        cbar_pad="10%",
                        cbar_size="25%",
                        axes_pad=(0.65, 0.25),
                        label_mode="L")

    else:
        # 2-d grid of plots.  Round the row count UP so the grid always
        # holds at least num axes (num=5 -> 3x2, not 2x2 with one missing).
        ny = int(math.sqrt(num))
        nx = int(math.ceil(num / ny))
        axes = AxesGrid(f, 111,
                        nrows_ncols=(nx, ny),
                        share_all=True,
                        cbar_mode="each",
                        cbar_location="right",
                        cbar_pad="2%",
                        axes_pad=(0.65, 0.25),
                        label_mode="L")

    return f, axes, cbar_title
def demo_grid_with_single_cbar_log(fig):
    """Draw a 2x2 grid of log-scaled demo images sharing one colorbar on top.

    The grid occupies the whole figure (position 111).  The demo image is
    shifted so its minimum becomes 0 and each panel is drawn with its own
    ``LogNorm``.  The single colorbar is built from the last image.  The
    figure is displayed with ``plt.show()``.
    """
    grid = AxesGrid(
        fig,
        111,  # the grid fills the whole figure
        nrows_ncols=(2, 2),
        axes_pad=0.0,
        share_all=True,
        label_mode="L",
        cbar_location="top",
        cbar_mode="single",
    )

    Z, extent = get_demo_image()
    # Shift so the smallest value becomes 0.  The data is then non-negative,
    # not strictly positive.
    # NOTE(review): LogNorm presumably masks that zero pixel; confirm this
    # is acceptable.
    Z -= np.min(Z)

    for cell in range(4):
        im = grid[cell].imshow(Z, extent=extent, interpolation="nearest",
                               norm=LogNorm())

    grid.cbar_axes[0].colorbar(im)
    for cax in grid.cbar_axes:
        cax.toggle_label(False)

    # share_all=True, so ticks set on the lower-left axes apply to every axes.
    llc = grid.axes_llc
    llc.set_xticks([-2, 0, 2])
    llc.set_yticks([-2, 0, 2])

    plt.show()
def img_pro(path):
    """Show a 4x4 grid of random sample images annotated with model predictions.

    For each of the first 16 entries of ``map_characters``, a random file in
    ``path`` whose name contains that entry is loaded and classified with
    ``model``.  The image is drawn with its actual label and the top-3
    predicted labels with their probabilities.

    Parameters
    ----------
    path : str
        Directory prefix.  It is concatenated directly with the file names,
        so it should end with a path separator.
    """
    F = plt.figure(1, (15, 20))
    grid = AxesGrid(F, 111, nrows_ncols=(4, 4), axes_pad=0, label_mode='1')
    font = cv2.FONT_HERSHEY_SIMPLEX

    for i in range(16):
        char = map_characters[i]
        # ``candidates`` (not ``list``) so the builtin is not shadowed.
        candidates = [path + k for k in os.listdir(path) if char in k]
        image = cv2.imread(np.random.choice(candidates))
        img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # The model is fed the image exactly as loaded by OpenCV (BGR); only
        # the displayed copy is converted to RGB.
        pic = cv2.resize(image, (pic_size, pic_size)).astype('float32') / 255
        a = model.predict(pic.reshape(1, pic_size, pic_size, 3))[0]

        actual = char.split('_')[0].title()
        # Top-3 classes by probability, each labelled with ITS OWN class name
        # (map_characters[k]).  The original labelled every line with
        # map_characters[i], i.e. the true class.
        top3 = sorted(enumerate(a), key=lambda kv: kv[1], reverse=True)[:3]
        text = ['{:s} : {:.1f}%'.format(map_characters[k].split('_')[0].title(), 100 * v)
                for k, v in top3]

        img = cv2.resize(img, (352, 352))
        cv2.rectangle(img, (0, 260), (215, 352), (255, 255, 255), -1)
        cv2.putText(img, 'Actual : %s' % actual, (10, 280), font, 0.7,
                    (0, 0, 0), 2, cv2.LINE_AA)
        for k, t in enumerate(text):
            cv2.putText(img, t, (10, 300 + k * 18), font, 0.65,
                        (0, 0, 0), 2, cv2.LINE_AA)
        grid[i].imshow(img)

    plt.show()
def demo_right_cbar(fig):
    """Draw a 2x2 image grid with one colorbar per row on the right edge.

    Row 0 uses the "spring" colormap and row 1 uses "winter".  Each row's
    colorbar is built from its right-hand image and labelled 'Foo'.
    """
    grid = AxesGrid(
        fig, 122,  # placed like subplot(122)
        nrows_ncols=(2, 2),
        axes_pad=0.10,
        label_mode="1",
        share_all=True,
        cbar_location="right",
        cbar_mode="edge",
        cbar_size="7%",
        cbar_pad="2%",
    )
    Z, extent = get_demo_image()
    row_cmaps = [plt.get_cmap("spring"), plt.get_cmap("winter")]

    for idx in range(4):
        row, col = divmod(idx, 2)
        im = grid[idx].imshow(Z, extent=extent, interpolation="nearest",
                              cmap=row_cmaps[row])
        if col == 1:
            # one colorbar per row, attached from the right-hand image
            grid.cbar_axes[row].colorbar(im)

    for cax in grid.cbar_axes:
        cax.toggle_label(True)
        cax.axis[cax.orientation].set_label('Foo')

    # share_all=True, so ticks set on the lower-left axes apply to every axes.
    llc = grid.axes_llc
    llc.set_xticks([-2, 0, 2])
    llc.set_yticks([-2, 0, 2])
def plot_filter(image_path, layer_name, output_dir):
    """Save a PDF mosaic of one VGG16 layer's activation maps for an image.

    Parameters
    ----------
    image_path : str or Path
        Image fed to ``load_images``; its stem is used in the output name.
    layer_name : str
        Name of the VGG16 layer whose output is visualised.
    output_dir : str or Path
        Directory receiving ``<image stem>-<layer_name>.pdf``.

    Only the first ``side**2`` channels are drawn, where ``side`` is the
    integer square root of the channel count.  All panels share one colour
    scale, so the single colorbar is valid for every panel.
    """
    base_model = VGG16(weights='imagenet')
    x = load_images([image_path])
    # Pass input/output positionally.  The keyword names differ between Keras
    # versions (``input``/``output`` vs ``inputs``/``outputs``).
    model = Model(base_model.input, base_model.get_layer(layer_name).output)
    layer_output = model.predict(x)

    side = int(layer_output.shape[-1] ** 0.5)
    maps = layer_output[0, :, :, :side ** 2]
    vmin, vmax = maps.min(), maps.max()

    fig = plt.figure()
    # cbar_mode="single" is required: without it the colorbar axes stay
    # hidden and the colorbar drawn below would never appear.
    grid = AxesGrid(fig, 111,
                    nrows_ncols=(side, side),
                    axes_pad=0.0,
                    share_all=True,
                    cbar_mode="single")
    for i in range(side ** 2):
        im = grid[i].imshow(maps[:, :, i], interpolation="nearest",
                            vmin=vmin, vmax=vmax)

    grid.cbar_axes[0].colorbar(im)
    for cax in grid.cbar_axes:
        cax.toggle_label(False)

    for ax in grid.axes_all:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    output_dir = Path(output_dir)
    fig_file = '{}-{}.pdf'.format(Path(image_path).stem, layer_name)
    fig.savefig(str(output_dir / fig_file))
    plt.close(fig)  # release the figure; repeated calls would otherwise leak
def draw_grid(fig, data_array, x_min, x_max, y_min, y_max):
    """Draw ``data_array`` as a single image with one colorbar on the right.

    Parameters
    ----------
    fig : Figure
        Figure to draw into; the grid fills the whole figure (position 111).
    data_array : 2-D array-like
        Image data.
    x_min, x_max, y_min, y_max : int
        Image extent.  Ticks are placed every 10 units from ``x_min`` /
        ``y_min``, excluding the upper bound (``range`` semantics), so these
        must be integers.

    The axes aspect ratio is fixed at 2.0.
    """
    grid = AxesGrid(fig, 111,  # a single 1x1 grid filling the figure
                    nrows_ncols=(1, 1),
                    axes_pad=0.2,
                    share_all=True,
                    label_mode="L",
                    cbar_location="right",
                    cbar_mode="single",
                    )

    extent = (x_min, x_max, y_min, y_max)
    im = grid[0].imshow(data_array, extent=extent, interpolation="nearest")
    grid[0].set_aspect(2.0, adjustable='box')

    grid.cbar_axes[0].colorbar(im)
    for cax in grid.cbar_axes:
        cax.toggle_label(True)

    # share_all=True, so ticks set on the lower-left axes apply to every axes.
    grid.axes_llc.set_xticks(range(x_min, x_max, 10))
    grid.axes_llc.set_yticks(range(y_min, y_max, 10))
def plot(D, i, flux_vol, xlabel, ylabel, name, sigma):
    """Project a 3-D volume along one axis, smooth it and save it as an image.

    Parameters
    ----------
    D : data object
        Supplies the coordinate arrays of the two axes that remain after the
        projection: ``x2``/``x3`` for ``i == 0``, and presumably ``x1``/``x3``
        and ``x1``/``x2`` for ``i == 1`` and ``i == 2``.  Only ``x2``/``x3``
        were used before; confirm ``x1`` exists on ``D``.
    i : int
        Axis of ``flux_vol`` to sum over.
    flux_vol : ndarray
        3-D volume.  The 2-element ``sigma`` passed to ``gaussian_filter``
        requires the projection to be 2-D.
    xlabel, ylabel : str
        Axis labels.
    name : str
        Output file passed to ``savefig``.
    sigma : float
        Gaussian smoothing width, in pixels, for both image axes.
    """
    flux = flux_vol.sum(axis=i)
    flux = nd.gaussian_filter(flux, sigma=(sigma, sigma), order=0)

    # The extent must come from the two axes that were NOT summed over.  The
    # original always used x2/x3, which is only right for i == 0.
    axis = i % flux_vol.ndim
    h_name, v_name = [c for k, c in enumerate(("x1", "x2", "x3")) if k != axis]
    h_coord = getattr(D, h_name)
    v_coord = getattr(D, v_name)

    fig = plt.figure(figsize=(6.5, 6.0))
    grid = AxesGrid(
        fig, 111,  # a single 1x1 grid filling the figure
        nrows_ncols=(1, 1),
        axes_pad=0.0,
        label_mode="1",
        share_all=True,
        cbar_location="right",
        cbar_mode="edge",
        cbar_size="2%",
        cbar_pad="0%",
    )
    # flux.T puts the first remaining axis horizontally, the second vertically.
    im = grid[0].imshow(flux.T, origin='lower',
                        extent=[h_coord[0], h_coord[-1], v_coord[0], v_coord[-1]])
    grid.cbar_axes[0].colorbar(im)

    grid.axes_llc.set_xlabel(xlabel, fontsize=20)
    grid.axes_llc.set_ylabel(ylabel, fontsize=20)

    fig.savefig(name)
    plt.close(fig)  # release the figure; repeated calls would otherwise leak
def triple_plot(ds, fields, lfpps=None, **kwargs):
    """Place yt slice plots of up to three fields side by side, one colorbar each.

    Parameters
    ----------
    ds : dataset
        Passed to ``slice_plot``.
    fields : sequence
        Fields to plot (at most three; one panel each).
    lfpps : dict, optional
        Per-call plot properties keyed by field.  A missing or empty entry
        falls back to the module-level ``fpps`` defaults.  ``None`` means no
        per-call properties; this avoids a shared mutable default.
    **kwargs
        Forwarded to ``slice_plot``.

    Returns
    -------
    Figure
    """
    if lfpps is None:
        lfpps = {}

    fig = plt.figure(figsize=(15, 5))
    grid = AxesGrid(fig, (0.075, 0.075, 0.85, 0.85),
                    nrows_ncols=(1, 3),
                    axes_pad=0.9,
                    share_all=True,
                    cbar_location="right",
                    cbar_mode="each",
                    cbar_size="3%",
                    cbar_pad="0%")

    ps = []
    for field in fields:
        fpp = lfpps.get(field, {}) or fpps.get(field, {})
        ps.append(slice_plot(ds, field, plotprops=fpp, **kwargs))

    # Re-home each yt plot onto the grid's axes and colorbar axes, then let
    # yt redraw into them.
    for i, field in enumerate(fields):
        plot = ps[i].plots[field]
        plot.figure = fig
        plot.axes = grid[i].axes
        plot.cax = grid.cbar_axes[i]
        ps[i]._setup_plots()

    return fig
def plot_filters_single_channel(t):
    """Save a 25x25 mosaic of 2-D kernels to ``w_gate.png`` with one shared colorbar.

    Cell ``(i, j)`` shows ``t[i, j]`` for ``i, j`` in ``0..24``.  Every cell
    uses the colour limits of the whole of ``t``.
    """
    nrows = 25
    ncols = 25
    # A plain figure.  The original also created an unused 25x25
    # plt.subplots() grid that stayed visible underneath the AxesGrid.
    fig = plt.figure(figsize=(38, 38))
    grid = AxesGrid(fig, 111,
                    nrows_ncols=(nrows, ncols),
                    axes_pad=0.05,
                    cbar_mode='single',
                    cbar_location='right',
                    cbar_pad=0.1
                    )

    vmin = np.amin(t)
    vmax = np.amax(t)

    # looping through all the kernels in each channel (row-major order)
    for i in range(nrows):
        for j in range(ncols):
            ax = grid[i * ncols + j]
            pcm = ax.imshow(t[i, j], vmin=vmin, vmax=vmax, cmap="plasma")
            ax.axis('off')

    # In "single" mode every cell shares one colorbar axes, so draw it once.
    grid.cbar_axes[0].colorbar(pcm)
    plt.savefig("w_gate.png")
    plt.close(fig)  # the figure is large; free it after saving
def _create_figure(self, figsize=None, **kwargs):
    """Build a fresh Figure holding one AxesGrid at subplot position 111.

    ``kwargs`` are forwarded unchanged to ``AxesGrid``.  The grid is stored
    on ``self.grid``, and ``self.next_grid`` is reset to ``{grid: 0}``.

    Returns
    -------
    (Figure, AxesGrid)
    """
    figure = Figure(figsize=figsize)
    axes_grid = AxesGrid(figure, 111, **kwargs)
    self.grid = axes_grid
    self.next_grid = {axes_grid: 0}
    return figure, axes_grid
def plotting_weights(save_folder, filename, mat, true_states=None, estimated_states=None, soz_ch_ids=None, sel_win_num=None):
    """Save a mosaic of square weight matrices to ``<save_folder>/W_<filename>.png``.

    Parameters
    ----------
    save_folder : str
        Output directory; created if missing.
    filename : str
        Stem of the output name ``W_<filename>.png``.
    mat : sequence of array-like
        One matrix per window.  Each must have a perfect-square number of
        elements and is reshaped to ``(N, N)``.
    true_states, estimated_states : sequence, optional
        When both are given, each panel is titled with its estimated state.
        Windows whose true state is non-zero get a larger title wrapped in
        ``--``.
    soz_ch_ids, sel_win_num : optional
        Accepted for interface compatibility; unused.

    Raises
    ------
    ValueError
        If ``mat`` is empty.
    """
    title_fontSize = 40

    num_wind = len(mat)
    if num_wind == 0:
        raise ValueError("mat must contain at least one matrix")
    plot_num_rows = int(np.ceil(num_wind ** 0.5))
    plot_num_cols = int(np.ceil(num_wind / plot_num_rows))

    os.makedirs(save_folder, exist_ok=True)

    # One colour scale for every panel, because they share a single colorbar.
    vmin = min(np.min(m) for m in mat)
    vmax = max(np.max(m) for m in mat)

    fig = plt.figure(num=None, figsize=(60, 40), dpi=120)
    grid = AxesGrid(fig, 111,
                    nrows_ncols=(plot_num_rows, plot_num_cols),
                    axes_pad=1,
                    cbar_mode='single',
                    cbar_location='right',
                    cbar_pad=0.1)

    label_titles = true_states is not None and estimated_states is not None
    for i, (ax, weights) in enumerate(zip(grid, mat)):
        N = int(round(np.size(weights) ** 0.5))
        im = ax.imshow(np.reshape(weights, (N, N)), vmin=vmin, vmax=vmax)
        if label_titles:
            if true_states[i] != 0:
                ax.set_title('--' + str(estimated_states[i]) + '--',
                             fontsize=title_fontSize + 5)
            else:
                ax.set_title(str(estimated_states[i]), fontsize=title_fontSize)

    # "single" mode: one shared colorbar, drawn once.
    cbar = grid.cbar_axes[0].colorbar(im)
    cbar.ax.tick_params(labelsize=30)

    # os.path.join: plain concatenation wrote outside save_folder when it
    # lacked a trailing separator.
    fig.savefig(os.path.join(save_folder, 'W_' + filename + '.png'))
    plt.close(fig)  # 60x40in @120dpi -- free it after saving
def demo_grid_with_each_cbar(fig):
    """Draw a 2x2 image grid where every image has its own colorbar on top.

    Parameters
    ----------
    fig : Figure
        Figure to draw into; the grid occupies the subplot(143) slot.
    """
    # Use the ``fig`` argument.  The original passed ``F``, which is not
    # defined in this function and only worked if a global ``F`` existed.
    grid = AxesGrid(
        fig, 143,  # similar to subplot(143)
        nrows_ncols=(2, 2),
        axes_pad=0.1,
        label_mode="1",
        share_all=True,
        cbar_location="top",
        cbar_mode="each",
        cbar_size="7%",
        cbar_pad="2%",
    )
    Z, extent = get_demo_image()
    for i in range(4):
        im = grid[i].imshow(Z, extent=extent, interpolation="nearest")
        grid.cbar_axes[i].colorbar(im)

    for cax in grid.cbar_axes:
        cax.toggle_label(False)

    # This affects all axes because we set share_all = True.
    grid.axes_llc.set_xticks([-2, 0, 2])
    grid.axes_llc.set_yticks([-2, 0, 2])
def double_plot(settings, ds):
    """Render yt slice plots of ``ds`` into a 1x2 panel figure and save a PDF.

    One panel is drawn per entry of ``fns``.
    NOTE(review): ``fns`` is not a parameter, so it must be a module-level
    global.  Its items are not used; every panel is drawn from ``ds``.
    Confirm this is intended.

    ``settings`` keys read:
    "option" ("top_down" or "side_on"), "field", "center", "width", "font",
    "streamlines", "lim" (two items), "save_name", and for "side_on" also
    "L" and "north_vector".

    The figure is saved to ``plots/<save_name>.pdf`` once all panels are drawn.
    """
    fig = plt.figure()
    grid = AxesGrid(fig, (0.09, 0.09, 0.8, 0.8),
                    nrows_ncols=(1, 2),
                    axes_pad=0.05,
                    label_mode="L",
                    share_all=True,
                    cbar_location="right",
                    cbar_mode="single",
                    cbar_size="5%",
                    cbar_pad="1.2%")

    for i, fn in enumerate(fns):
        if settings["option"] == "top_down":
            slc = yt.SlicePlot(ds, 'z', settings["field"], center=settings["center"],
                               width=settings["width"], fontsize=settings["font"])
            if settings["streamlines"]:
                slc.annotate_streamlines('velocity_x', 'velocity_y', density=1.5, factor=16,
                                         plot_args={'color': 'black', 'linewidth': 0.25})

        if settings["option"] == "side_on":
            slc = yt.OffAxisSlicePlot(ds, settings["L"], settings["field"],
                                      center=settings["center"],
                                      north_vector=settings["north_vector"],
                                      width=settings["width"], fontsize=settings["font"])
            # Raw strings: the backslashes are LaTeX, not Python escapes
            # ("\ ", "\m", "\o" are invalid escapes that newer Pythons warn about).
            slc.set_xlabel(r'x $\ (\mathrm{R}_{\odot})$')
            slc.set_ylabel(r'z $\ (\mathrm{R}_{\odot})$')
            if settings["streamlines"]:
                slc.annotate_streamlines('magnetic_field_x', 'magnetic_field_z', density=1.5,
                                         factor=16,
                                         plot_args={'color': 'white', 'linewidth': 0.75})

        slc.set_cmap(field=settings["field"], cmap='jet')
        slc.set_zlim(settings["field"], settings["lim"][0], settings["lim"][1])

        # Re-home the yt plot onto this panel and redraw it there.
        # NOTE(review): in "single" cbar mode only cbar_axes[0] is visible;
        # both panels share the same zlim, so that colorbar covers both.
        plot = slc.plots[settings["field"]]
        plot.figure = fig
        plot.axes = grid[i].axes
        plot.cax = grid.cbar_axes[i]
        slc._setup_plots()

    slc.save("plots/" + settings["save_name"] + ".pdf")
def AX_perturbations(fig, perturbations):
    """Draw a 1x10 strip of 28x28 perturbation images with one coolwarm colorbar.

    The strip occupies the subplot(211) slot.  The first ``nb_classes``
    (a module-level name) perturbations are reshaped to 28x28 and drawn on
    the fixed colour range [-1, 1].  A single colorbar on the right is
    ticked at -1, 0 and 1.
    """
    grid = AxesGrid(
        fig, 211,  # similar to subplot(211)
        nrows_ncols=(1, 10),
        axes_pad=0.0,
        share_all=True,
        label_mode="1",
        cbar_location="right",
        cbar_mode="single",
        cbar_size="7%",
        cbar_pad="3%",
    )
    for i in range(nb_classes):
        img = perturbations[i].reshape([28, 28])
        im = grid[i].imshow(img, interpolation="nearest", cmap=cm.coolwarm,
                            vmin=-1., vmax=1.)
        # Real booleans: modern matplotlib treats the string 'off' as truthy,
        # so the ticks were not actually hidden.
        grid[i].tick_params(which='both', bottom=False, left=False,
                            labelbottom=False, labelleft=False)

    # ``im`` already carries vmin/vmax = [-1, 1]; a separate norm passed next
    # to a mappable is ignored by colorbar, so it is not passed.
    grid.cbar_axes[0].colorbar(im)
    grid.cbar_axes[0].set_yticks((-1, 0, 1))
def AX_actual(fig, adv_x, top_1, confidence, ylabel):
    """Draw a 1x10 strip of 28x28 images labelled with predicted class and confidence.

    Each panel's x-label is ``"<top_1[i]> (<confidence>)"``, with the
    confidence truncated (ROUND_DOWN) to two decimals.  ``ylabel`` is drawn,
    rotated 270 degrees, on the right of the last panel.
    """
    grid = AxesGrid(
        fig, 111,  # the strip fills the whole figure
        nrows_ncols=(1, 10),
        axes_pad=0.0,
        share_all=True,
        label_mode="all")
    for i in range(nb_classes):
        img = adv_x[i].reshape([28, 28])
        grid[i].imshow(img, cmap='gray')
        # Real booleans: modern matplotlib treats the string 'off' as truthy.
        grid[i].tick_params(which='both', bottom=False, left=False,
                            labelbottom=False, labelleft=False)
        # Truncate rather than round, e.g. 0.999 -> "0.99".
        conf = str(Decimal(str(confidence[i])).quantize(Decimal('0.01'),
                                                        rounding=ROUND_DOWN))
        xlabel = str(top_1[i]) + " (" + conf + ")"
        grid[i].set_xlabel(xlabel, labelpad=2.0, fontsize=12)

    grid[9].yaxis.set_label_position("right")
    # ylabel should be a string
    grid[9].set_ylabel(ylabel, labelpad=14.0, fontsize=15, rotation=270)
def plot_multiple_model_weights(weights_to_plot):
    """Show several models' weights side by side as heat maps on one colour scale.

    Each entry is passed through ``rectangularfy`` before plotting.  All
    panels share the global min/max and a single colorbar on the right, and
    the figure is displayed with ``plt.show()``.
    """
    stacked = np.array([rectangularfy(w) for w in weights_to_plot])
    hi = stacked.max()
    lo = stacked.min()

    fig = plt.figure()
    grid = AxesGrid(
        fig, 111,
        nrows_ncols=(1, len(stacked)),
        axes_pad=0.05,
        share_all=True,
        label_mode="L",
        cbar_location="right",
        cbar_mode="single",
    )

    for panel, weights in zip(grid, stacked):
        heatmap = panel.imshow(weights, cmap="coolwarm", vmax=hi, vmin=lo)
        panel.set_xticks([])
        panel.set_yticks([])

    grid.cbar_axes[0].colorbar(heatmap)
    plt.show()
def displayMaps(self, map_list, figFilename, n_rows=2):
    """Save all maps in a grid on a shared colour scale with one colorbar.

    Parameters
    ----------
    map_list : sequence of 2-D array-like
        Maps to draw, in grid order.
    figFilename : str
        Output path passed to ``savefig``.
    n_rows : int
        Number of grid rows.  The column count is rounded UP so every map
        gets a panel; ``np.round`` could round down and silently drop maps.
    """
    # print() function: the Python 2 print statement is a syntax error in Python 3.
    print('map_list: ', (np.array(map_list)).shape)
    # NOTE(review): 200x200 inches is very large; kept as in the original.
    fig = plt.figure(figsize=(200, 200))
    grid = AxesGrid(
        fig, 111,
        nrows_ncols=(n_rows, int(np.ceil(len(map_list) / n_rows))),
        axes_pad=0.02,
        share_all=True,
        label_mode="L",
        cbar_location="right",
        # All maps share vmin/vmax and only one colorbar is drawn.  The
        # original "each" mode left the remaining colorbar axes empty.
        cbar_mode="single",
    )

    vmin = float('inf')
    vmax = -float('inf')
    for cur_map in map_list:
        vmin = min(vmin, np.min(cur_map))
        vmax = max(vmax, np.max(cur_map))
    print('DEBUG:vmax, vmin: ', vmax, vmin)

    for cur_map, ax in zip(map_list, grid):
        im = ax.imshow(cur_map, vmin=vmin, vmax=vmax)
    grid.cbar_axes[0].colorbar(im)

    plt.savefig(figFilename)
    plt.close(fig)  # very large figure -- free it after saving
def demo_grid_with_each_cbar_labelled(fig):
    """Draw a 2x2 image grid where every image has its own labelled colorbar.

    Each panel uses a different colour range from ``limits``, and its
    colorbar is ticked only at that range's two ends.

    Parameters
    ----------
    fig : Figure
        Figure to draw into; the grid occupies the subplot(144) slot.
    """
    # Use the ``fig`` argument.  The original passed ``F``, which is not
    # defined in this function and only worked if a global ``F`` existed.
    grid = AxesGrid(
        fig, 144,  # similar to subplot(144)
        nrows_ncols=(2, 2),
        axes_pad=(0.45, 0.15),
        label_mode="1",
        share_all=True,
        cbar_location="right",
        cbar_mode="each",
        cbar_size="7%",
        cbar_pad="2%",
    )
    Z, extent = get_demo_image()

    # Use a different colorbar range every time
    limits = ((0, 1), (-2, 2), (-1.7, 1.4), (-1.5, 1))
    for i in range(4):
        im = grid[i].imshow(Z, extent=extent, interpolation="nearest",
                            vmin=limits[i][0], vmax=limits[i][1])
        grid.cbar_axes[i].colorbar(im)

    for i, cax in enumerate(grid.cbar_axes):
        cax.set_yticks((limits[i][0], limits[i][1]))

    # This affects all axes because we set share_all = True.
    grid.axes_llc.set_xticks([-2, 0, 2])
    grid.axes_llc.set_yticks([-2, 0, 2])
def quad_plot(ds, fields, width=(5000, "km")):
    """Arrange yt slice plots of up to four fields in a 2x2 grid, one colorbar each.

    A plot is made for every field with ``slice_plot(ds, field, width=width)``
    and then moved onto the grid's axes and colorbar axes.

    Returns
    -------
    Figure
    """
    fig = plt.figure()
    grid = AxesGrid(fig, (0.075, 0.075, 0.85, 0.85),
                    nrows_ncols=(2, 2),
                    axes_pad=0.9,
                    share_all=True,
                    cbar_location="right",
                    cbar_mode="each",
                    cbar_size="3%",
                    cbar_pad="0%")

    slices = [slice_plot(ds, field, width=width) for field in fields]

    for idx, (field, slc) in enumerate(zip(fields, slices)):
        target = slc.plots[field]
        target.figure = fig
        target.axes = grid[idx].axes
        target.cax = grid.cbar_axes[idx]
        # let yt redraw into the re-homed axes
        slc._setup_plots()

    return fig
def multiplot():
    """Write one 1x5-panel projection PNG per simulation output DD0055..DD0454.

    For every output, the same field is projected along z from five runs
    (``nref10_track_2`` and ``nref10_track_lowfdbk_1..4``).  Each projection
    is centred on the position interpolated in redshift from
    ``complete_track``, and the result is saved as
    ``<DDxxxx>_multiplot_<field>_projection.png``.
    """
    field = 'O_p5_number_density'
    # presumably the number of kpc per code length unit -- TODO confirm
    kpc = 143886.
    runs = ['nref10_track_2',
            'nref10_track_lowfdbk_1',
            'nref10_track_lowfdbk_2',
            'nref10_track_lowfdbk_3',
            'nref10_track_lowfdbk_4']

    track = Table.read('complete_track', format='ascii')
    track.sort('col1')

    for n in range(55, 455):
        fig = plt.figure()
        grid = AxesGrid(fig, (0.5, 0.5, 1.5, 1.5),
                        nrows_ncols=(1, 5),
                        axes_pad=0.1,
                        label_mode="1",
                        share_all=True,
                        cbar_location="right",
                        cbar_mode="edge",
                        cbar_size="5%",
                        cbar_pad="0%")

        strset = 'DD{:04d}'.format(n)  # e.g. DD0055, DD0100

        for i, run in enumerate(runs):
            snap = run + '/' + strset + '/' + strset
            ds = yt.load(snap)
            zsnap = ds.get_parameter('CosmologyCurrentRedshift')
            trident.add_ion_fields(ds, ions=['C IV', 'O VI', 'H I', 'Si III'])

            centerx = np.interp(zsnap, track['col1'], 0.5 * (track['col2'] + track['col5']))
            centery = np.interp(zsnap, track['col1'], track['col3'] + 30. / kpc)
            centerz = np.interp(zsnap, track['col1'], 0.5 * (track['col4'] + track['col7']))
            center = [centerx, centery, centerz]
            box = ds.r[center[0] - 200. / kpc:center[0] + 200. / kpc,
                       center[1] - 250. / kpc:center[1] + 250. / kpc,
                       center[2] - 40. / kpc:center[2] + 40. / kpc]

            # projection
            p = yt.ProjectionPlot(ds, 'z', field, center=center,
                                  width=((120, 'kpc'), (240, 'kpc')), data_source=box)
            if field == 'density':
                p.set_unit('density', 'Msun / pc**2')
                p.set_zlim('density', 0.01, 1000)
            if 'O_p5' in field:
                p.set_zlim("O_p5_number_density", 1e11, 1e15)
            if i < 1:
                p.annotate_timestamp(corner='upper_left', redshift=True, draw_inset_box=True,
                                     text_args={'color': 'white', 'size': 'small'})

            # This forces the ProjectionPlot to redraw itself on the AxesGrid axes.
            plot = p.plots[field]
            plot.figure = fig
            plot.axes = grid[i].axes
            p._setup_plots()

        # Finally, save the assembled grid figure and free it; 400 figures
        # were previously left open.
        fig.savefig(strset + '_multiplot_' + field + '_projection.png', bbox_inches='tight')
        plt.close(fig)
def demo_grid_with_single_cbar(fig):
    """Draw a 2x2 image grid, in the subplot(142) slot, sharing one colorbar on top.

    The colorbar is built from the last image drawn.  ``toggle_label(False)``
    is applied to every colorbar axes, and ticks at -2, 0 and 2 are set on
    the shared axes.
    """
    grid = AxesGrid(
        fig,
        142,  # placed like subplot(142)
        nrows_ncols=(2, 2),
        axes_pad=0.0,
        share_all=True,
        label_mode="L",
        cbar_location="top",
        cbar_mode="single",
    )

    Z, extent = get_demo_image()
    for cell in range(4):
        im = grid[cell].imshow(Z, extent=extent, interpolation="nearest")

    grid.cbar_axes[0].colorbar(im)
    for cax in grid.cbar_axes:
        cax.toggle_label(False)

    # share_all=True, so ticks set on the lower-left axes apply to every axes.
    llc = grid.axes_llc
    llc.set_xticks([-2, 0, 2])
    llc.set_yticks([-2, 0, 2])
def displayMaps(self, map_list, figFilename=None, n_rows=1, title=''):
    """Plot maps in a grid on a shared colour scale with a single colorbar.

    Parameters
    ----------
    map_list : sequence of 2-D array-like
        Maps to draw, in grid order.
    figFilename : str, optional
        If given, the figure is saved there; otherwise it is shown without
        blocking.
    n_rows : int
        Number of grid rows.  The column count is rounded UP so every map
        gets a panel; ``np.round`` could round down and silently drop maps.
    title : str
        Figure title, drawn with ``suptitle`` when non-empty.  This argument
        used to be accepted but ignored.
    """
    fig = plt.figure()
    if title:
        fig.suptitle(title)
    grid = AxesGrid(
        fig, 111,
        nrows_ncols=(n_rows, int(np.ceil(len(map_list) / n_rows))),
        axes_pad=0.01,
        share_all=True,
        label_mode="L",
        cbar_location="right",
        cbar_mode="single",
    )

    vmin = float('inf')
    vmax = -float('inf')
    for cur_map in map_list:
        vmin = min(vmin, np.min(cur_map))
        vmax = max(vmax, np.max(cur_map))

    for cur_map, ax in zip(map_list, grid):
        im = ax.imshow(cur_map, vmin=vmin, vmax=vmax)
    grid.cbar_axes[0].colorbar(im)

    if figFilename is not None:
        plt.savefig(figFilename)
    else:
        plt.show(block=False)
def heatmap_in_one_figure(vals, pars, cmap=None):
    """Draw several heat maps in one figure on a shared [0, 1] colour scale.

    Parameters
    ----------
    vals : list of DataFrame
        Each is converted with ``plot3D`` and drawn with ``pcolor``.
    pars : tuple of int
        ``(nrows, ncols)`` of the grid.
    cmap : optional
        Colormap passed to ``pcolor``.

    Returns
    -------
    fig : Figure
    ax : Axes
        The LAST axes drawn into (kept for backward compatibility).

    Raises
    ------
    ValueError
        If ``vals`` is empty.  This previously surfaced as UnboundLocalError.
    """
    from mpl_toolkits.axes_grid1 import AxesGrid

    vals = list(vals)
    if not vals:
        raise ValueError("vals must contain at least one dataframe")

    fig = plt.figure()
    grid = AxesGrid(
        fig, 111,
        nrows_ncols=pars,
        axes_pad=0.05,
        share_all=True,
        label_mode="L",
        cbar_location="right",
        cbar_mode="single",
    )
    for val, ax in zip(vals, grid):
        myplot3d = plot3D(val)
        hm = ax.pcolor(myplot3d.x, myplot3d.y, myplot3d.z, vmin=0, vmax=1, cmap=cmap)
    grid.cbar_axes[0].colorbar(hm)
    return fig, ax
def demo_bottom_cbar(fig):
    """Draw a 2x2 image grid with one colorbar per column at the bottom.

    The grid is filled column by column (``direction="column"``).  Column 0
    uses "autumn" and column 1 uses "summer".  Each column's colorbar is
    built from its second image and labelled "Bar".
    """
    grid = AxesGrid(
        fig, 121,  # placed like subplot(121)
        nrows_ncols=(2, 2),
        axes_pad=0.10,
        share_all=True,
        label_mode="1",
        cbar_location="bottom",
        cbar_mode="edge",
        cbar_pad=0.25,
        cbar_size="15%",
        direction="column")

    Z, extent = get_demo_image()
    col_cmaps = [plt.get_cmap("autumn"), plt.get_cmap("summer")]

    for idx in range(4):
        col, pos = divmod(idx, 2)
        im = grid[idx].imshow(Z, extent=extent, interpolation="nearest",
                              cmap=col_cmaps[col])
        if pos == 1:
            grid.cbar_axes[col].colorbar(im)

    for cax in grid.cbar_axes:
        cax.toggle_label(True)
        cax.axis[cax.orientation].set_label("Bar")

    # share_all=True, so ticks set on the lower-left axes apply to every axes.
    llc = grid.axes_llc
    llc.set_xticks([-2, 0, 2])
    llc.set_yticks([-2, 0, 2])
def displayMaps(self, map_list, figFilename, n_rows=1):
    """Save maps in a grid on a shared colour scale with a single colorbar.

    The figure is titled "Maximum %-diff in input map-pairs".  If every map
    has the same min and max, ``vmax`` is bumped by 1 so the colour range is
    not empty.

    Parameters
    ----------
    map_list : sequence of 2-D array-like
        Maps to draw, in grid order.
    figFilename : str
        Output path passed to ``savefig``.
    n_rows : int
        Number of grid rows.  The column count is rounded UP so every map
        gets a panel; ``np.round`` could round down and silently drop maps.
    """
    fig = plt.figure()
    fig.suptitle("Maximum %-diff in input map-pairs", fontsize=16)
    nrows_ncols = (n_rows, int(np.ceil(len(map_list) / n_rows)))
    grid = AxesGrid(fig, 111,
                    nrows_ncols=nrows_ncols,
                    axes_pad=0.1,
                    share_all=True,
                    label_mode="L",
                    cbar_location="right",
                    cbar_mode="single",
                    )

    vmin = float('inf')
    vmax = -float('inf')
    for cur_map in map_list:
        vmin = min(vmin, np.min(cur_map))
        vmax = max(vmax, np.max(cur_map))
    if vmax == vmin:
        vmax += 1

    for cur_map, ax in zip(map_list, grid):
        im = ax.imshow(cur_map, vmin=vmin, vmax=vmax)
    grid.cbar_axes[0].colorbar(im)

    plt.savefig(figFilename)
    plt.close(fig)  # free the figure; repeated calls would otherwise leak
def create_maps(self):
    """Create the figure window and a (map_row x map_col) AxesGrid for array maps.

    The figure size is derived from ``self.chan_map.geometry`` and
    ``self.mm_per_pixel``.  All grid axes start with their axis decorations
    turned off, and the window is shown immediately.  Sets ``self._grid``,
    ``self._g_idx`` (0), ``self._cbar`` (None) and ``self._figwin``.
    """
    geom = self.chan_map.geometry
    # Rough over-estimate of inches per pixel; another rule may fit better.
    inches_per_px = 0.039 * self.mm_per_pixel
    # single array map: width, height
    map_w = geom[1] * inches_per_px
    map_h = geom[0] * inches_per_px
    figwin = SimpleFigure(figsize=(self.map_col * map_w, self.map_row * map_h))
    fig = figwin.figure

    text_inch = 11 / fig.dpi
    grid = AxesGrid(
        fig, 111,
        nrows_ncols=(self.map_row, self.map_col),
        axes_pad=1.5 * text_inch,
        cbar_mode='single',
        cbar_location='right',
        cbar_pad='2%',
        cbar_size='4%'
    )
    for ax in grid:
        ax.axis('off')

    self._grid = grid
    self._g_idx = 0
    self._cbar = None  # keep a reference to the colorbar or it disappears
    figwin.show()
    self._figwin = figwin
def plotImgMosaic(imgOrig, extent, origin='upper', T=0):
    """Plot six channel-pair images as a 3x2 mosaic of ``20*log10|img|`` (dB).

    The keys "21", "41", "31", "42", "32" and "43" of ``imgOrig`` are drawn
    row by row.  The left column holds 21, 31 and 32; the right column holds
    41, 42 and 43, with their y-labels and ticks on the right.  Every image
    is squeezed and, if ``T`` is truthy, transposed.  ``imgOrig`` itself is
    not modified.  All panels share one colour scale, so the single
    colorbar on top is valid for every panel.

    Returns
    -------
    (fig, grid)
    """
    fig = plt.figure()
    grid = AxesGrid(
        fig, 111,  # similar to subplot(111)
        nrows_ncols=(3, 2),
        axes_pad=0.0,
        share_all=True,
        label_mode="all",
        cbar_location="top",
        cbar_mode="single",
    )

    img = dict(imgOrig)  # shallow copy: only the dict entries are replaced
    for key in img:
        img[key] = np.squeeze(img[key])
        if T:
            img[key] = img[key].T

    order = ("21", "41", "31", "42", "32", "43")
    db = {ij: 20 * np.log10(np.abs(img[ij])) for ij in order}

    # Shared colour limits over the finite values (zeros give -inf dB).
    # Previously every panel autoscaled on its own while a single colorbar
    # (from the last panel) was shown for all of them.
    finite = np.concatenate([d[np.isfinite(d)].ravel() for d in db.values()])
    if finite.size:
        vmin, vmax = finite.min(), finite.max()
    else:
        vmin = vmax = None

    for pos, ij in enumerate(order):
        ax = grid[pos]
        im = ax.imshow(db[ij], cmap='jet', extent=extent, origin=origin,
                       vmin=vmin, vmax=vmax)
        ax.set_ylabel(ij)
        if pos % 2:  # right-hand column
            ax.yaxis.set_ticks_position('right')
            ax.yaxis.set_label_position('right')

    grid.cbar_axes[0].colorbar(im)
    return fig, grid
def plot_activation_gradients(all_timesteps_activations: np.array, neuron_heatmap_size: tuple,
                              show_title=True, absolute=True, save=None):
    """Plot the change in activations between consecutive time steps as heat maps.

    One panel is drawn per transition ``t -> t+1`` for a single sample.

    Parameters
    ----------
    all_timesteps_activations : sequence of torch.Tensor or np.ndarray
        Activations for every time step of ONE sample.
    neuron_heatmap_size : tuple
        Shape each difference is reshaped to for display.
    show_title : bool
        Add a figure title.
    absolute : bool
        If True, plot the magnitude ``|a[t+1] - a[t]|`` on a 0..2 "Reds"
        scale.  Otherwise plot the signed difference on a -2..2 "coolwarm"
        scale.
    save : str or None
        Where to save the figure; when None the figure is shown instead.
    """
    num_timesteps = len(all_timesteps_activations)
    assert all([type(out) in (torch.Tensor, np.ndarray) for out in all_timesteps_activations]), \
        "This function only takes all the activations for all the time steps of a single sample."

    # Colour scale and map are the same for every panel.
    if absolute:
        vmin, vmax, colormap = 0, 2, "Reds"
    else:
        vmin, vmax, colormap = -2, 2, 'coolwarm'

    fig = plt.figure()
    last_activations = all_timesteps_activations[0]
    grid = AxesGrid(
        fig, 111,
        nrows_ncols=(1, num_timesteps - 1),
        axes_pad=0.05,
        share_all=True,
        label_mode="L",
        cbar_location="right",
        cbar_mode="single",
    )

    for t, (axis, current_activations) in enumerate(
            zip(grid, all_timesteps_activations[1:])):
        activation_gradients = current_activations - last_activations
        if absolute:
            # Take the magnitude.  Otherwise every decrease would be clipped
            # to vmin=0 and vanish from the 0-based "Reds" map.
            activation_gradients = abs(activation_gradients)

        heatmap = axis.imshow(
            activation_gradients.reshape(*neuron_heatmap_size),
            cmap=colormap, vmin=vmin, vmax=vmax)
        axis.set_xlabel("t={} -> t={}".format(t, t + 1))
        axis.set_xticks([])
        axis.set_yticks([])

        last_activations = current_activations

    grid.cbar_axes[0].colorbar(heatmap)

    if show_title:
        fig.suptitle("Activation value gradients over {} time steps".format(
            num_timesteps))

    if save is None:
        plt.show()
    else:
        plt.savefig(save, bbox_inches="tight")
        plt.close()
def accumulate_patches_into_heatmaps(self, all_test_output, outpath_prefix=''):
    """Average per-patch class scores into per-pixel heat maps and save a figure.

    The figure shows the original image (``<test_imagepath>.png``) next to
    a 2x2 grid with one heat map per class.  Each heat map holds the mean
    score of every patch covering a pixel.  It is saved to
    ``plots/<outpath_prefix>_<image stem>.png`` at 600 dpi.

    Parameters
    ----------
    all_test_output : ndarray
        Indexed ``[patch_index, class]``.  At most 4 classes fit in the grid.
    outpath_prefix : str
        Prefix of the output file name.
    """
    outpath = "plots/%s_%s.png" % (
        outpath_prefix, path.splitext(path.basename(
            self.test_imagepath))[0])

    # http://matplotlib.org/examples/axes_grid/demo_axes_grid.html
    fig = plt.figure()
    grid = AxesGrid(
        fig, 143,  # similar to subplot(143)
        nrows_ncols=(1, 1))
    orig_img = imread(self.test_imagepath + '.png')
    grid[0].imshow(orig_img)

    grid = AxesGrid(
        fig, 144,  # similar to subplot(144)
        nrows_ncols=(2, 2),
        axes_pad=0.15,
        label_mode="1",
        share_all=True,
        cbar_location="right",
        cbar_mode="each",
        cbar_size="7%",
        cbar_pad="2%",
    )

    # Patch coverage does not depend on the class, so compute the patch
    # bounds and per-pixel patch counts once, not once per class.
    # Assumes nth_patch is deterministic.
    patch_bounds = [self.nth_patch(n) for n in range(self.num_patch_centers)]
    normalizer = numpy.zeros(self.ds.image_shape[:2])
    for i_start, i_end, j_start, j_end in patch_bounds:
        normalizer[i_start:i_end, j_start:j_end] += 1

    for klass in range(all_test_output.shape[1]):
        accumulator = numpy.zeros(self.ds.image_shape[:2])
        for n, (i_start, i_end, j_start, j_end) in enumerate(patch_bounds):
            accumulator[i_start:i_end, j_start:j_end] += all_test_output[n, klass]
        # Pixels covered by no patch give 0/0 = nan (presumably drawn blank).
        normalized_img = accumulator / normalizer
        im = grid[klass].imshow(normalized_img, interpolation="nearest", vmin=0, vmax=1)
        grid.cbar_axes[klass].colorbar(im)

    grid.axes_llc.set_xticks([])
    grid.axes_llc.set_yticks([])

    print("Saving figure as: %s" % outpath)
    plt.savefig(outpath, dpi=600, bbox_inches='tight')
    plt.close(fig)  # free the figure; repeated calls would otherwise leak
def plot_pathological_imgs():
    """Save four example images side by side, without axes, to ``out.png`` at 300 dpi."""
    fig = plt.figure()
    grid = AxesGrid(fig, 111, nrows_ncols=(1, 4))
    names = ['23050_right.png', '2468_left.png', '15450_left.png', '406_left.png']
    imgs = [imread(n) for n in names]
    for ax, img in zip(grid, imgs):
        ax.imshow(img)
        # Hide decorations on EVERY panel.  plt.axis('off') only affected
        # the current axes.
        ax.axis('off')
    fig.savefig('out.png', dpi=300)
    plt.close(fig)
def main(args):
    """Animate corner (storm-cell) data per input file, with optional polygon overlays.

    Reads corner volumes from ``args.inputDataFiles``, optional sim-tag files
    and optional polygon files (``args.polys``).  Coordinates are
    transformed relative to a station when one is known.  One panel per
    input file is drawn via ``MakeCornerPlots``.  The animation is saved to
    ``args.saveImgFile`` and/or shown, depending on the arguments.

    Raises
    ------
    ValueError
        If the number of titles does not match the number of input files.
    """
    import os.path
    from matplotlib.animation import ArtistAnimation

    if args.trackTitles is None:
        args.trackTitles = [os.path.dirname(filename) for filename in args.inputDataFiles]

    if len(args.inputDataFiles) == 0:
        print("WARNING: No corner control files given!")

    if len(args.trackTitles) != len(args.inputDataFiles):
        raise ValueError("The number of TITLEs does not match the number"
                         " of INPUTFILEs.")

    if args.statName is not None and args.statLonLat is None:
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData['LON'], statData['LAT'])

    if args.layout is None:
        args.layout = (1, len(args.inputDataFiles))

    if args.figsize is None:
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    if args.simTagFiles is None:
        args.simTagFiles = []

    # Treat a missing polygon list like an empty one (zip(None, ...) fails).
    polyfiles = args.polys if args.polys is not None else []

    cornerVolumes = [ReadCorners(inFileName,
                                 os.path.dirname(inFileName))['volume_data']
                     for inFileName in args.inputDataFiles]
    polyData = [_load_verts(f, list(vol['stormCells'] for vol in vols))
                for f, vols in zip(polyfiles, cornerVolumes)]

    multiTags = [ReadSimTagFile(fname) for fname in args.simTagFiles]
    if len(multiTags) == 0:
        multiTags = [None]
    if len(multiTags) < len(cornerVolumes):
        # Rudimentary broadcasting
        tagMult = max(int(len(cornerVolumes) // len(multiTags)), 1)
        multiTags = multiTags * tagMult

    if args.statLonLat is not None:
        for vols in cornerVolumes:
            for vol in vols:
                CoordinateTransform(vol['stormCells'],
                                    args.statLonLat[0],
                                    args.statLonLat[1])
        for verts in polyData:
            CoordinateTransform(verts,
                                args.statLonLat[0],
                                args.statLonLat[1])

    showMap = (args.statLonLat is not None and args.displayMap)
    showRadar = (args.statLonLat is not None and args.radarFile is not None)

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout,
                    share_all=True, axes_pad=0.32)

    theAnim, radAnim = MakeCornerPlots(theFig, grid, cornerVolumes,
                                       args.trackTitles, showMap, showRadar,
                                       tail=args.tail,
                                       startFrame=args.startFrame,
                                       endFrame=args.endFrame,
                                       radarFiles=args.radarFile,
                                       fade=args.fade,
                                       multiTags=multiTags,
                                       tag_filters=args.filters)

    # The polygon loop previously used the undefined names ``polys``,
    # ``startFrame``, ``endFrame`` and ``theTimer``.
    # NOTE(review): the polygons are synchronised with the radar animation's
    # timer when there is one, and each ``verts`` is sliced by the requested
    # frame range -- confirm both against MakeCornerPlots/_load_verts.
    theTimer = radAnim.event_source if radAnim is not None else None
    frameStop = None if args.endFrame is None else args.endFrame + 1
    polyAnims = []
    for ax, verts in zip(grid, polyData):
        polyAnim = ArtistAnimation(theFig,
                                   _to_polygons(verts[args.startFrame:frameStop], ax),
                                   event_source=theTimer)
        polyAnims.append(polyAnim)

    if args.xlims is not None and np.prod(grid.get_geometry()) > 0:
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and np.prod(grid.get_geometry()) > 0:
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None:
        # Build the list explicitly: ``radAnim + polyAnims`` raised TypeError
        # when there was no radar animation (radAnim is None).
        extraAnims = ([radAnim] if radAnim is not None else []) + polyAnims
        theAnim.save(args.saveImgFile, extra_anim=extraAnims)

    if args.doShow:
        plt.show()
def main(args):
    """Animate corner (storm-cell) data from simulations and/or input files.

    Input files come from ``args.simName`` (via its ``simParams.conf``) and
    from ``args.inputDataFiles``.  Sim-tag filters are applied, coordinates
    are transformed relative to a station when one is known, and the domain
    is taken from a track file or from the corner volumes.  Optional map
    layers and radar images are added.  The corner animation is saved to
    ``args.saveImgFile`` and/or shown.

    Raises
    ------
    ValueError
        If the number of titles does not match the number of input files.
    """
    import os.path  # for os.path
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import AxesGrid

    inputDataFiles = []
    titles = []
    simTagFiles = []

    if args.simName is not None:
        dirName = os.path.join(args.directory, args.simName)
        simParams = ParamUtils.ReadSimulationParams(os.path.join(dirName,
                                                                 "simParams.conf"))
        inputDataFiles.append(os.path.join(dirName, simParams['inputDataFile']))
        titles.append(args.simName)
        simTagFiles.append(os.path.join(dirName, simParams['simTagFile']))

    # Add on any files specified at the command-line
    inputDataFiles += args.inputDataFiles
    titles += args.inputDataFiles
    if args.simTagFiles is not None:
        simTagFiles += args.simTagFiles

    if len(inputDataFiles) == 0:
        print("WARNING: No inputDataFiles given or found!")

    if len(titles) != len(inputDataFiles):
        raise ValueError("The number of TITLEs does not match the"
                         " number of INPUTFILEs.")

    if len(simTagFiles) < len(inputDataFiles):
        # Not an error, just pad with None.  Use ``extend``: ``append`` added
        # ONE element (a list), which ReadSimTagFile then got as a file name.
        simTagFiles.extend([None] * (len(inputDataFiles) - len(simTagFiles)))

    if args.statName is not None and args.statLonLat is None:
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData['LON'], statData['LAT'])

    if args.layout is None:
        args.layout = (1, len(inputDataFiles))

    if args.figsize is None:
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    cornerVolumes = [ReadCorners(inFileName,
                                 os.path.dirname(inFileName))['volume_data']
                     for inFileName in inputDataFiles]

    multiTags = [(ReadSimTagFile(fname) if fname is not None else None)
                 for fname in simTagFiles]

    for vols, simTags in zip(cornerVolumes, multiTags):
        keeperIDs = process_tag_filters(simTags, args.filters)
        if keeperIDs is None:
            continue

        for vol in vols:
            vol['stormCells'] = FilterTrack(vol['stormCells'],
                                            cornerIDs=keeperIDs)

    if args.statLonLat is not None:
        for vols in cornerVolumes:
            for vol in vols:
                CoordinateTransform(vol['stormCells'],
                                    args.statLonLat[0],
                                    args.statLonLat[1])

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout,
                    share_all=True, axes_pad=0.32)

    if args.trackFile is not None:
        (tracks, falarms) = FilterMHTTracks(*ReadTracks(args.trackFile))

        if args.statLonLat is not None:
            CoordinateTransform(tracks + falarms,
                                args.statLonLat[0],
                                args.statLonLat[1])

        (xLims, yLims, frameLims) = DomainFromTracks(tracks + falarms)
    else:
        volumes = []
        for aVol in cornerVolumes:
            volumes.extend(aVol)
        (xLims, yLims, tLims, frameLims) = DomainFromVolumes(volumes)

    showMap = (args.statLonLat is not None and args.displayMap)

    if showMap:
        bmap = Basemap(projection='cyl', resolution='l',
                       suppress_ticks=False,
                       llcrnrlat=yLims[0], llcrnrlon=xLims[0],
                       urcrnrlat=yLims[1], urcrnrlon=xLims[1])

    startFrame = args.startFrame
    endFrame = args.endFrame
    tail = args.tail

    if startFrame is None:
        startFrame = frameLims[0]

    if endFrame is None:
        endFrame = frameLims[1]

    if tail is None:
        tail = 0

    # A common event_source for synchronizing all the animations
    theTimer = None

    # Make the corners big
    big = False

    if args.radarFile is not None and args.statLonLat is not None:
        if endFrame - frameLims[0] >= len(args.radarFile):
            # Not enough radar files, so truncate the tracks.
            endFrame = (len(args.radarFile) + frameLims[0]) - 1

        files = args.radarFile[startFrame - frameLims[0]:(endFrame + 1) - frameLims[0]]
        radAnim = RadarAnim(theFig, files)
        theTimer = radAnim.event_source
        for ax in grid:
            radAnim.add_axes(ax, alpha=0.6, zorder=0)

        # Radar images make it difficult to see corners, so make 'em big
        big = True
    else:
        radAnim = None

    theAnim = CornerAnimation(theFig, endFrame - startFrame + 1,
                              tail=tail, interval=250, blit=False,
                              event_source=theTimer, fade=args.fade)

    for (index, volData) in enumerate(cornerVolumes):
        curAxis = grid[index]

        if showMap:
            # NOTE(review): ``mapLayers`` is not defined in this function, so
            # it must be a module-level global -- confirm.
            PlotMapLayers(bmap, mapLayers, curAxis, zorder=0.1)

        volFrames = [frameVol['frameNum'] for frameVol in volData]
        startIdx = volFrames.index(startFrame)
        endIdx = volFrames.index(endFrame)
        volTimes = [frameVol['volTime'] for frameVol in volData]
        startT = volTimes[startIdx]
        endT = volTimes[endIdx]

        corners = PlotCorners(volData, (startT, endT), axis=curAxis,
                              big=big)

        curAxis.set_title(titles[index])
        if not showMap:
            curAxis.set_xlabel("X")
            curAxis.set_ylabel("Y")
        else:
            curAxis.set_xlabel("Longitude")
            curAxis.set_ylabel("Latitude")

        theAnim.AddCornerVolume(corners)

    if args.xlims is not None and np.prod(grid.get_geometry()) > 0:
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and np.prod(grid.get_geometry()) > 0:
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None:
        if radAnim is not None:
            radAnim = [radAnim]
        theAnim.save(args.saveImgFile, extra_anim=radAnim)

    if args.doShow:
        plt.show()
def main(args) :
    """Draw a static grid of track plots, one panel per track file.

    Track files come from ``args.trackFiles`` and, when ``args.simName`` is
    given, from the trackers listed in that simulation's ``simParams.conf``.
    Each panel may be drawn over a radar image (when both ``args.radarFile``
    and a station lon/lat are available) and map layers (``args.displayMap``).
    When a truth track file is available, each tracker's segments are
    compared against the truth segments and plotted with ``PlotSegments``;
    otherwise the tracks are drawn with ``PlotPlainTracks``.

    The figure is saved to ``args.saveImgFile`` (if not None) and shown
    when ``args.doShow`` is true.

    Raises
    ------
    ValueError
        If the number of titles does not match the number of track files.
    """
    import os.path
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import AxesGrid

    if args.bw_mode :
        BW_mode() # from TrackPlot module

    # FIXME: Currently, the code allows for trackFiles to be listed as well
    #        as providing a simulation (which trackfiles are automatically
    #        grabbed). Both situations can not be handled right now, though.
    trackFiles = []
    trackTitles = []

    if args.simName is not None :
        dirName = os.path.join(args.directory, args.simName)
        simParams = ParamUtils.ReadSimulationParams(
                        os.path.join(dirName, "simParams.conf"))

        if args.trackRuns is not None :
            simParams['trackers'] = ExpandTrackRuns(simParams['trackers'],
                                                    args.trackRuns)

        trackFiles = [os.path.join(dirName,
                                   simParams['result_file'] + '_' + aTracker)
                      for aTracker in simParams['trackers']]

        # Default titles are the tracker names.  Explicit titles
        # (args.trackTitles) are appended exactly once, further below;
        # assigning them here too would extend the list with itself.
        # list() copies so the later += does not mutate simParams.
        if args.trackTitles is None :
            trackTitles = list(simParams['trackers'])

        if args.truthTrackFile is None :
            args.truthTrackFile = os.path.join(dirName,
                                               simParams['noisyTrackFile'])

        if args.simTagFile is None :
            args.simTagFile = os.path.join(dirName, simParams['simTagFile'])

    trackFiles += args.trackFiles
    if args.trackTitles is None :
        trackTitles += args.trackFiles
    else :
        trackTitles += args.trackTitles

    if len(trackFiles) == 0 :
        print("WARNING: No trackFiles given or found!")

    if len(trackTitles) != len(trackFiles) :
        raise ValueError("The number of TITLEs do not match the"
                         " number of TRACKFILEs.")

    if args.statName is not None and args.statLonLat is None :
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData['LON'], statData['LAT'])

    if args.layout is None :
        args.layout = (1, len(trackFiles))

    if args.figsize is None :
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    trackerData = [FilterMHTTracks(*ReadTracks(trackFile))
                   for trackFile in trackFiles]

    keeperIDs = None
    if args.simTagFile is not None :
        simTags = ParamUtils.ReadSimTagFile(args.simTagFile)
        keeperIDs = ParamUtils.process_tag_filters(simTags, args.filters)

    if args.statLonLat is not None :
        for aTracker in trackerData :
            CoordinateTransform(aTracker[0] + aTracker[1],
                                args.statLonLat[0],
                                args.statLonLat[1])

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout, aspect=False,
                    share_all=True, axes_pad=0.35)

    if args.truthTrackFile is not None :
        (true_tracks, true_falarms) = FilterMHTTracks(
                                          *ReadTracks(args.truthTrackFile))

        if args.statLonLat is not None :
            CoordinateTransform(true_tracks + true_falarms,
                                args.statLonLat[0],
                                args.statLonLat[1])

        true_AssocSegs = CreateSegments(true_tracks)
        true_FAlarmSegs = CreateSegments(true_falarms)

        if keeperIDs is not None :
            true_AssocSegs = FilterSegments(keeperIDs, true_AssocSegs)
            true_FAlarmSegs = FilterSegments(keeperIDs, true_FAlarmSegs)

        (xLims, yLims, frameLims) = DomainFromTracks(true_tracks +
                                                     true_falarms)
    else :
        true_AssocSegs = None
        true_FAlarmSegs = None

        # No truth available: derive the domain from all trackers' output.
        stackedTracks = []
        for aTracker in trackerData :
            stackedTracks += aTracker[0] + aTracker[1]
        (xLims, yLims, frameLims) = DomainFromTracks(stackedTracks)

    endFrame = args.endFrame
    tail = args.tail

    if endFrame is None :
        endFrame = frameLims[1]

    if tail is None :
        tail = endFrame - frameLims[0]

    startFrame = endFrame - tail

    showMap = (args.statLonLat is not None and args.displayMap)

    # A single radar file is shown; with several, args.endFrame selects it.
    if args.radarFile is not None and args.statLonLat is not None :
        if len(args.radarFile) > 1 and args.endFrame is not None :
            args.radarFile = args.radarFile[args.endFrame]
        else :
            args.radarFile = args.radarFile[-1]
        raddata = LoadRastRadar(args.radarFile)
    else :
        raddata = None

    if showMap :
        bmap = Basemap(projection='cyl', resolution='i',
                       suppress_ticks=False,
                       llcrnrlat=yLims[0], llcrnrlon=xLims[0],
                       urcrnrlat=yLims[1], urcrnrlon=xLims[1])

    for index, (tracks, falarms) in enumerate(trackerData) :
        curAxis = grid[index]

        if raddata is not None :
            MakeReflectPPI(raddata['vals'][0], raddata['lats'],
                           raddata['lons'], meth='pcmesh', ax=curAxis,
                           colorbar=False, axis_labels=False,
                           zorder=0, alpha=0.6)

        if showMap :
            PlotMapLayers(bmap, mapLayers, curAxis)

        if true_AssocSegs is not None and true_FAlarmSegs is not None :
            trackAssocSegs = CreateSegments(tracks)
            trackFAlarmSegs = CreateSegments(falarms)

            if keeperIDs is not None :
                trackAssocSegs = FilterSegments(keeperIDs, trackAssocSegs)
                trackFAlarmSegs = FilterSegments(keeperIDs, trackFAlarmSegs)

            truthtable = CompareSegments(true_AssocSegs, true_FAlarmSegs,
                                         trackAssocSegs, trackFAlarmSegs)
            PlotSegments(truthtable, (startFrame, endFrame), axis=curAxis,
                         fade=args.fade)
        else :
            if keeperIDs is not None :
                # Comprehensions keep these as lists on Python 2 and 3
                # (map() is lazy on Python 3).
                tracks = [FilterTrack(trk, cornerIDs=keeperIDs)
                          for trk in tracks]
                falarms = [FilterTrack(trk, cornerIDs=keeperIDs)
                           for trk in falarms]
                CleanupTracks(tracks, falarms)

            PlotPlainTracks(tracks, falarms,
                            startFrame, endFrame, axis=curAxis,
                            fade=args.fade)

        curAxis.set_title(trackTitles[index])
        if not showMap :
            curAxis.set_xlabel("X")
            curAxis.set_ylabel("Y")
        else :
            curAxis.set_xlabel("Longitude")
            curAxis.set_ylabel("Latitude")

    # share_all=True: limits set on the first axis apply to every panel.
    if args.xlims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None :
        theFig.savefig(args.saveImgFile, bbox_inches='tight')

    if args.doShow :
        plt.show()
def main(args) :
    """Plot tracker output compared against truth tracks in a grid of panels.

    Tracker files (``args.trackFiles``) and truth files
    (``args.truthTrackFile``) are paired up.  When their counts differ, the
    shorter list is repeated to match the longer one; this requires the
    shorter list to be non-empty and to evenly divide the longer one.  Tag
    sets read from ``args.simTagFiles`` (or a single ``None`` when there are
    none) are repeated to match the number of truth sets.  The actual
    drawing is delegated to ``MakeComparePlots``.

    Raises
    ------
    ValueError
        If the number of titles does not match the number of track files,
        or if the tracker and truth lists cannot be broadcast to each other.
    """
    if args.bw_mode :
        BW_mode()

    if len(args.trackFiles) == 0 :
        print("WARNING: No trackFiles given!")

    if len(args.truthTrackFile) == 0 :
        print("WARNING: No truth trackFiles given!")

    if args.trackTitles is None :
        args.trackTitles = args.trackFiles
    elif len(args.trackTitles) != len(args.trackFiles) :
        raise ValueError("The number of TITLEs does not match the number"
                         " of TRACKFILEs")

    if args.statName is not None and args.statLonLat is None :
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData['LON'], statData['LAT'])

    if args.layout is None :
        args.layout = (1, max(len(args.trackFiles), len(args.truthTrackFile)))

    if args.figsize is None :
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    if args.simTagFiles is None :
        args.simTagFiles = []

    trackerData = [FilterMHTTracks(*ReadTracks(trackFile))
                   for trackFile in args.trackFiles]
    truthData = [FilterMHTTracks(*ReadTracks(trackFile))
                 for trackFile in args.truthTrackFile]
    multiTags = [ReadSimTagFile(fname) for fname in args.simTagFiles]
    if len(multiTags) == 0 :
        multiTags = [None]

    if args.statLonLat is not None :
        for aTracker in trackerData + truthData :
            CoordinateTransform(aTracker[0] + aTracker[1],
                                args.statLonLat[0],
                                args.statLonLat[1])

    if len(trackerData) != len(truthData) :
        # Basic broadcasting needed!  The shorter list must be non-empty
        # (otherwise the modulo below would divide by zero) and must evenly
        # divide the longer one.
        if len(truthData) > len(trackerData) :
            # Need to extend track data to match with the number of truth sets
            if not trackerData or len(truthData) % len(trackerData) != 0 :
                raise ValueError("Can't extend TRACKFILE list to match with"
                                 " the TRUTHFILE list!")
        else :
            # Need to extend truth sets to match with the number of track data
            if not truthData or len(trackerData) % len(truthData) != 0 :
                raise ValueError("Can't extend TRUTHFILE list to match with"
                                 " the TRACKFILE list!")

        trkMult = max(int(len(truthData) // len(trackerData)), 1)
        trthMult = max(int(len(trackerData) // len(truthData)), 1)

        trackerData = trackerData * trkMult
        truthData = truthData * trthMult
        args.trackTitles = args.trackTitles * trkMult

    # Repeat the tag sets to match the (possibly extended) truth sets.  This
    # is done even when tracker and truth counts already agree, so a single
    # tag file (or none at all) applies to every panel.
    tagMult = max(int(len(truthData) // len(multiTags)), 1)
    multiTags = multiTags * tagMult

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout, aspect=False,
                    share_all=True, axes_pad=0.45)

    showMap = (args.statLonLat is not None and args.displayMap)

    # A single radar file is shown; with several, args.endFrame selects it.
    if args.radarFile is not None and args.statLonLat is not None :
        if len(args.radarFile) > 1 and args.endFrame is not None :
            args.radarFile = args.radarFile[args.endFrame]
        else :
            args.radarFile = args.radarFile[-1]

        data = LoadRastRadar(args.radarFile)
        for ax in grid :
            MakeReflectPPI(data['vals'][0], data['lats'], data['lons'],
                           meth='pcmesh', ax=ax, colorbar=False,
                           axis_labels=False, zorder=0, alpha=0.6)

    MakeComparePlots(grid, trackerData, truthData, args.trackTitles, showMap,
                     endFrame=args.endFrame, tail=args.tail, fade=args.fade,
                     multiTags=multiTags, tag_filters=args.filters)

    # share_all=True: limits set on the first axis apply to every panel.
    if args.xlims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None :
        theFig.savefig(args.saveImgFile)

    if args.doShow :
        plt.show()
def main(args) :
    """Animate tracks in a grid of panels, with optional polygon overlays.

    Track files come from ``args.trackFiles`` and, when ``args.simName`` is
    given, from the trackers listed in that simulation's ``simParams.conf``.
    Each panel's line artists are registered with a shared ``SegAnimator``.
    When ``args.radarFile`` and a station lon/lat are available, a
    ``RadarAnim`` supplies the timer that keeps all animations in sync.
    Polygon files (``args.polys``) are paired, in order, with the first
    track files and panels, and each is animated with an ``ArtistAnimation``.

    The animation is saved to ``args.saveImgFile`` (if not None) and shown
    when ``args.doShow`` is true.

    Raises
    ------
    ValueError
        If more polygon files than track files are given.
    """
    import os.path # for os.path.join()

    if args.bw_mode :
        BW_mode() # from TrackPlot module

    # FIXME: Currently, the code allows for trackFiles to be listed as well
    #        as providing a simulation (which trackfiles are automatically
    #        grabbed). Both situations can not be handled right now, though.
    trackFiles = []
    trackTitles = []
    # Treat a missing polygon option like an empty list.
    polyfiles = args.polys if args.polys is not None else []

    if args.statName is not None and args.statLonLat is None :
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData['LON'], statData['LAT'])

    if args.simName is not None :
        dirName = os.path.join(args.directory, args.simName)
        simParams = ParamUtils.ReadSimulationParams(
                        os.path.join(dirName, "simParams.conf"))

        if args.trackRuns is not None :
            simParams['trackers'] = ExpandTrackRuns(simParams['trackers'],
                                                    args.trackRuns)

        trackFiles = [os.path.join(dirName,
                                   simParams['result_file'] + '_' + aTracker)
                      for aTracker in simParams['trackers']]
        # Copy so that the += below does not mutate simParams.
        trackTitles = list(simParams['trackers'])

        if args.truthTrackFile is None :
            args.truthTrackFile = os.path.join(dirName,
                                               simParams['noisyTrackFile'])

        if args.simTagFile is None :
            args.simTagFile = os.path.join(dirName, simParams['simTagFile'])

    trackFiles += args.trackFiles
    trackTitles += args.trackFiles

    # Explicit titles replace the default ones entirely.
    if args.trackTitles is not None :
        trackTitles = args.trackTitles

    if len(trackFiles) == 0 :
        print("WARNING: No trackFiles given or found!")

    if args.layout is None :
        args.layout = (1, len(trackFiles))

    if args.figsize is None :
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    if len(trackFiles) < len(polyfiles):
        raise ValueError("Can not have more polygon files than trackfiles!")

    trackerData = [FilterMHTTracks(*ReadTracks(trackFile))
                   for trackFile in trackFiles]
    polyData = [_load_verts(polyFile, trkTracks + trkFalarms)
                for polyFile, (trkTracks, trkFalarms)
                in zip(polyfiles, trackerData)]

    keeperIDs = None
    if args.simTagFile is not None :
        simTags = ParamUtils.ReadSimTagFile(args.simTagFile)
        keeperIDs = ParamUtils.process_tag_filters(simTags, args.filters)

    if args.statLonLat is not None :
        for aTracker in trackerData :
            CoordinateTransform(aTracker[0] + aTracker[1],
                                args.statLonLat[0],
                                args.statLonLat[1])

        for polys in polyData:
            CoordinateTrans_lists(polys, args.statLonLat[0],
                                  args.statLonLat[1])

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout,
                    share_all=True, axes_pad=0.45)

    if args.truthTrackFile is not None :
        (true_tracks, true_falarms) = FilterMHTTracks(
                                          *ReadTracks(args.truthTrackFile))

        if args.statLonLat is not None :
            CoordinateTransform(true_tracks + true_falarms,
                                args.statLonLat[0],
                                args.statLonLat[1])

        true_AssocSegs = CreateSegments(true_tracks)
        true_FAlarmSegs = CreateSegments(true_falarms)

        if keeperIDs is not None :
            true_AssocSegs = FilterSegments(keeperIDs, true_AssocSegs)
            true_FAlarmSegs = FilterSegments(keeperIDs, true_FAlarmSegs)

        (xLims, yLims, frameLims) = DomainFromTracks(true_tracks +
                                                     true_falarms)
    else :
        true_AssocSegs = None
        true_FAlarmSegs = None

        # No truth available: derive the domain from all trackers' output.
        stackedTracks = []
        for aTracker in trackerData :
            stackedTracks += aTracker[0] + aTracker[1]
        (xLims, yLims, frameLims) = DomainFromTracks(stackedTracks)

    startFrame = args.startFrame
    endFrame = args.endFrame
    tail = args.tail

    if startFrame is None :
        startFrame = 0

    if endFrame is None :
        endFrame = frameLims[1]

    if tail is None :
        tail = endFrame - startFrame

    # A common timer for all animations for syncing purposes.
    theTimer = None

    if args.radarFile is not None and args.statLonLat is not None :
        if endFrame >= len(args.radarFile) :
            # Not enough radar files, so truncate the tracks.
            endFrame = len(args.radarFile) - 1

        files = args.radarFile[startFrame:(endFrame + 1)]
        radAnim = RadarAnim(theFig, files)
        theTimer = radAnim.event_source
        for ax in grid :
            radAnim.add_axes(ax, alpha=0.6, zorder=0)
    else :
        radAnim = None

    showMap = (args.statLonLat is not None and args.displayMap)

    if showMap :
        bmap = Basemap(projection='cyl', resolution='i',
                       suppress_ticks=False,
                       llcrnrlat=yLims[0], llcrnrlon=xLims[0],
                       urcrnrlat=yLims[1], urcrnrlon=xLims[1])

    animator = SegAnimator(theFig, startFrame, endFrame, tail,
                           event_source=theTimer, fade=args.fade)

    for index, (tracks, falarms) in enumerate(trackerData):
        curAxis = grid[index]

        if showMap :
            PlotMapLayers(bmap, mapLayers, curAxis, zorder=0.1)

        if true_AssocSegs is not None and true_FAlarmSegs is not None :
            trackAssocSegs = CreateSegments(tracks)
            trackFAlarmSegs = CreateSegments(falarms)

            if keeperIDs is not None :
                trackAssocSegs = FilterSegments(keeperIDs, trackAssocSegs)
                trackFAlarmSegs = FilterSegments(keeperIDs, trackFAlarmSegs)

            truthtable = CompareSegments(true_AssocSegs, true_FAlarmSegs,
                                         trackAssocSegs, trackFAlarmSegs)
            l, d = Animate_Segments(truthtable, (startFrame, endFrame),
                                    axis=curAxis)
        else :
            if keeperIDs is not None :
                # Comprehensions keep these as lists on Python 2 and 3
                # (map() is lazy on Python 3).
                tracks = [FilterTrack(trk, cornerIDs=keeperIDs)
                          for trk in tracks]
                falarms = [FilterTrack(trk, cornerIDs=keeperIDs)
                           for trk in falarms]
                CleanupTracks(tracks, falarms)

            l, d = Animate_PlainTracks(tracks, falarms,
                                       (startFrame, endFrame), axis=curAxis)

        # Hand this panel's artists to the shared animator.
        animator._lines.extend(l)
        animator._lineData.extend(d)

        curAxis.set_title(trackTitles[index])
        if not showMap :
            curAxis.set_xlabel("X")
            curAxis.set_ylabel("Y")
        else :
            curAxis.set_xlabel("Longitude (degrees)")
            curAxis.set_ylabel("Latitude (degrees)")

    from matplotlib.animation import ArtistAnimation

    # Each panel animates its own polygon set.  The frames must come from
    # this iteration's ``verts``, not a name left over from an earlier loop.
    polyAnims = []
    for ax, verts in zip(grid, polyData):
        polyAnim = ArtistAnimation(theFig,
                                   _to_polygons(verts[startFrame:endFrame + 1],
                                                ax),
                                   event_source=theTimer)
        polyAnims.append(polyAnim)

    # share_all=True: limits set on the first axis apply to every panel.
    if args.xlims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and np.prod(grid.get_geometry()) > 0 :
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None :
        extraAnims = [radAnim] if radAnim is not None else []
        animator.save(args.saveImgFile, extra_anim=extraAnims + polyAnims)

    if args.doShow :
        plt.show()
def main(args):
    """Plot tracks from each track file in its own panel of a grid.

    Tag sets are read from ``args.simTagFiles``.  When there are none, a
    single ``None`` entry is used.  The tag sets are repeated when there are
    more trackers than tag sets.  When a station lon/lat is available, track
    coordinates are transformed and an optional radar image is drawn under
    every panel.  Drawing of the tracks is delegated to ``MakeTrackPlots``.

    Raises
    ------
    ValueError
        If the number of titles does not match the number of track files.
    """
    if args.bw_mode:
        BW_mode()

    if len(args.trackFiles) == 0:
        print("WARNING: No trackFiles given!")

    if args.trackTitles is None:
        args.trackTitles = args.trackFiles
    elif len(args.trackTitles) != len(args.trackFiles):
        raise ValueError("The number of TITLEs do not match the"
                         " number of TRACKFILEs.")

    if args.statName is not None and args.statLonLat is None:
        statData = ByName(args.statName)[0]
        args.statLonLat = (statData["LON"], statData["LAT"])

    if args.layout is None:
        args.layout = (1, len(args.trackFiles))

    if args.figsize is None:
        args.figsize = plt.figaspect(float(args.layout[0]) / args.layout[1])

    trackerData = [FilterMHTTracks(*ReadTracks(trackFile))
                   for trackFile in args.trackFiles]

    if args.statLonLat is not None:
        for aTracker in trackerData:
            CoordinateTransform(aTracker[0] + aTracker[1],
                                args.statLonLat[0], args.statLonLat[1])

    # An empty list is treated like "no tag files".  Otherwise the broadcast
    # below would divide by zero whenever any track files are given.
    if not args.simTagFiles:
        args.simTagFiles = [None]

    multiTags = [ReadSimTagFile(fname) if fname is not None else None
                 for fname in args.simTagFiles]

    if len(trackerData) > len(multiTags):
        # Very rudimentary broadcasting of multiTags to match trackerData
        tagMult = max(int(len(trackerData) // len(multiTags)), 1)
        multiTags = multiTags * tagMult

    theFig = plt.figure(figsize=args.figsize)
    grid = AxesGrid(theFig, 111, nrows_ncols=args.layout, aspect=False,
                    share_all=True, axes_pad=0.45)

    showMap = args.statLonLat is not None and args.displayMap

    # Can only do this if all other data being displayed will be in
    # lon/lat coordinates.  A single radar file is shown; with several,
    # args.endFrame selects which one.
    if args.radarFile is not None and args.statLonLat is not None:
        if len(args.radarFile) > 1 and args.endFrame is not None:
            args.radarFile = args.radarFile[args.endFrame]
        else:
            args.radarFile = args.radarFile[-1]

        data = LoadRastRadar(args.radarFile)
        for ax in grid:
            MakeReflectPPI(
                data["vals"][0], data["lats"], data["lons"],
                meth="pcmesh", ax=ax, colorbar=False, axis_labels=False,
                zorder=0, alpha=0.6,
            )

    MakeTrackPlots(
        grid, trackerData, args.trackTitles, showMap,
        endFrame=args.endFrame, tail=args.tail, fade=args.fade,
        multiTags=multiTags, tag_filters=args.filters,
    )

    # share_all=True: limits set on the first axis apply to every panel.
    if args.xlims is not None and np.prod(grid.get_geometry()) > 0:
        grid[0].set_xlim(args.xlims)

    if args.ylims is not None and np.prod(grid.get_geometry()) > 0:
        grid[0].set_ylim(args.ylims)

    if args.saveImgFile is not None:
        theFig.savefig(args.saveImgFile)

    if args.doShow:
        plt.show()