def _build_road(vehicledata, lanes, road_length=100):
    """Rasterise vehicle records into a boolean occupancy array.

    Returns an array of shape ``(timesteps, lanes, road_length)`` in which a
    cell is True when a vehicle occupies it at that timestep.

    Each vehicle record is read as ``vehicle[0]`` = position,
    ``vehicle[1]`` = lane and ``vehicle[3]`` = size (``vehicle[2]`` is not
    used here). A vehicle of size 4 covers lanes ``lane`` and ``lane + 1``
    at positions ``pos - 1`` and ``pos``. Any other size covers the single
    cell ``(lane, pos)``.
    """
    timesteps = vehicledata.shape[0]
    road = np.zeros((timesteps, lanes, road_length), dtype=bool)
    for t, vehicles in enumerate(vehicledata):
        for vehicle in vehicles:
            # Cast so that float-typed record arrays can still be used as indices.
            pos = int(vehicle[0])
            lane = int(vehicle[1])
            size = vehicle[3]
            if size == 4:
                # Clamp the left edge. With pos == 0 the slice -1:1 is empty,
                # and the vehicle would silently disappear from the diagram.
                road[t, lane:lane + 2, max(pos - 1, 0):pos + 1] = True
            else:
                road[t, lane, pos] = True
    return road


def tempo_diagram(vehicledata, ratio, density, lanes, road_length=100):
    """Plot a space-time occupancy diagram per lane and save it as a PNG.

    Parameters
    ----------
    vehicledata : array-like
        Must expose ``.shape[0]`` (number of timesteps). Iterating it yields,
        per timestep, an iterable of vehicle records (see ``_build_road``).
    ratio, density : float
        Only used to build the output file name ``CR.<ratio>.D<density>.png``.
    lanes : int
        Number of lanes. One panel is drawn per lane.
    road_length : int, optional
        Number of cells along the road (default 100, the previously
        hard-coded value).
    """
    road = _build_road(vehicledata, lanes, road_length)

    fig = plt.figure(1)
    # Figure 1 is reused across calls. Clear it so that a second call does
    # not draw new axes on top of the previous diagram.
    fig.clf()
    if lanes > 1:
        grid = ImageGrid(fig, 111, nrows_ncols=(1, lanes), axes_pad=0.1)
        axes = [grid[i] for i in range(lanes)]
    else:
        axes = [fig.add_subplot(111)]

    for i, ax in enumerate(axes):
        ax.imshow(road[:, i, :], cmap="binary", interpolation="nearest")
        # The single-lane branch used to format an undefined ``i`` here
        # (NameError). Every panel now gets the same ``L_<lane>`` title.
        ax.set_title(r"$L_%s$" % i)
        ax.set_xticklabels([])
    fig.savefig('CR.%.2f.D%.2f.png' % (ratio, density), bbox_inches="tight")
def _create_subplots(self, kind='', figsize=None, nrows=1, ncols=1, rect=111,
                     cbar_mode='single', squeeze=False, **kwargs):
    """Create (or recreate) ``self.fig`` and a grid of axes ``self.axes``.

    If ``self.fig`` already exists, it is cleared and its figure number is
    reused, so repeated calls redraw into the same figure.

    :Kwargs:
        - kind (str, default: '')
            The kind of plot. For plotting matrices or images
            (`matplotlib.pyplot.imshow`), choose `matrix`; the axes are then
            created with ``ImageGrid``. Otherwise leave it blank and
            `matplotlib.pyplot.subplots` is used.
        - figsize (tuple, default: None)
            Size of the figure.
        - nrows, ncols (int, default: 1)
            Shape of the subplot arrangement. Ignored if ``nrows_ncols`` is
            given.
        - nrows_ncols (tuple, default: (nrows, ncols))
            Alternative way to give the subplot arrangement.
        - rect (default: 111)
            Grid position passed to ``ImageGrid``. It is a keyword argument
            here, and is only used when ``kind == 'matrix'``.
        - cbar_mode (str, default: 'single')
            Colorbar mode passed to ``ImageGrid``. Only used when
            ``kind == 'matrix'``.
        - squeeze (bool, default: False)
            Passed to `matplotlib.pyplot.subplots`. Only used when
            ``kind != 'matrix'``.
        - **kwargs
            Any other keyword arguments accepted by ``ImageGrid`` or
            `matplotlib.pyplot.subplots`, depending on ``kind``.

    :Returns:
        ``(self.fig, self.axes)``: the figure and the grid of axes. In the
        `plt.subplots` case the axes array is flattened with ``ravel()``.
    """
    nrows_ncols = kwargs.pop('nrows_ncols', (nrows, ncols))

    # Reuse the number of an existing figure so that the same window is
    # redrawn. AttributeError covers both "self.fig was never set" and
    # "self.fig is None". The former bare ``except:`` also swallowed
    # KeyboardInterrupt and unrelated errors.
    try:
        num = self.fig.number
        self.fig.clf()
    except AttributeError:
        num = None

    if kind == 'matrix':
        # NOTE(review): ``self.figure`` is assumed to be defined on this class
        # (e.g. a wrapper around plt.figure); confirm.
        self.fig = self.figure(figsize=figsize, num=num)
        self.axes = ImageGrid(self.fig, rect,
                              nrows_ncols=nrows_ncols,
                              cbar_mode=cbar_mode,
                              **kwargs)
    else:
        self.fig, self.axes = plt.subplots(nrows=nrows_ncols[0],
                                           ncols=nrows_ncols[1],
                                           figsize=figsize,
                                           squeeze=squeeze,
                                           num=num,
                                           **kwargs)
        self.axes = self.axes.ravel()  # flatten the 2-D axes array
    self.kind = kind
    self.subplotno = -1  # will get +1 after the plot command
    self.nrows_ncols = nrows_ncols
    return (self.fig, self.axes)
def plot_color_index(g_i_list=None, i_list=None, list_names=None, limits=None,
                     savedir='./', filename=None, show=True, title='',
                     redshifts=None):
    '''Plot colour (M_g - M_i) against M_i for several tracks.

    Parameters
    ----------
    g_i_list : sequence of sequences
        M_g - M_i values (y axis), one sequence per track.
    i_list : sequence of sequences
        M_i values (x axis), one sequence per track. Required despite the
        ``None`` default: its length sets the number of curves.
    list_names : sequence
        Name of each track. It appears in the legend as ``'<name> Gyr'``.
    limits : sequence of 4 numbers, optional
        ``(xmin, xmax, ymin, ymax)`` axis limits.
    savedir, filename : str, optional
        When ``filename`` is given, the figure is saved to
        ``os.path.join(savedir, filename)`` at 600 dpi.
    show : bool
        Call ``fig.show()`` at the end.
    title : str
        Axes title.
    redshifts : sequence, optional
        Point ``j`` of each of the first two tracks is annotated with
        ``'z=<redshifts[j]>'``. No annotations are drawn when None.
    '''
    import os

    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import numpy as np
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    colormap = plt.cm.gist_ncar
    color_cycle = [colormap(x) for x in np.linspace(0, 0.9, len(i_list))]
    font_scale = 12
    params = {
        'axes.labelsize': font_scale,
        'axes.titlesize': font_scale,
        # 'text.fontsize' was removed from matplotlib; 'font.size' replaces it.
        'font.size': font_scale,
        'legend.fontsize': font_scale * 3 / 4,
        'xtick.labelsize': font_scale,
        'ytick.labelsize': font_scale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (6, 6),
    }
    if color_cycle:
        # 'axes.color_cycle' was replaced by 'axes.prop_cycle'. An empty
        # cycle is rejected, so this is only set when there are curves.
        params['axes.prop_cycle'] = mpl.cycler(color=color_cycle)
    plt.rcParams.update(params)

    # Create figure
    fig = plt.figure()
    grid = ImageGrid(fig, (1, 1, 1), nrows_ncols=(1, 1), ngrids=1,
                     direction='row', axes_pad=1, aspect=False,
                     share_all=True, label_mode='all')
    ax = grid[0]

    for k, mags in enumerate(i_list):
        colours = g_i_list[k]
        ax.plot(mags, colours, label='%s Gyr' % list_names[k], marker='s')
        # Only the first two tracks are annotated with redshifts.
        if k < 2 and redshifts is not None:
            for j, z in enumerate(redshifts):
                ax.annotate('z=%s' % z, xy=(mags[j], colours[j]),
                            textcoords='offset points', xytext=(2, 3),
                            size=font_scale * 0.75)

    if limits is not None:
        ax.set_xlim(limits[0], limits[1])
        ax.set_ylim(limits[2], limits[3])

    # Adjust aesthetics
    ax.set_xlabel(r'$M_i$ (mag)')
    ax.set_ylabel(r'$M_g - M_i$ (mag)')
    ax.grid(True)
    # 'bottom right' is not a legend location; 'lower right' is.
    ax.legend(loc='lower right')
    ax.set_title(title)

    if filename is not None:
        plt.savefig(os.path.join(savedir, filename), bbox_inches='tight',
                    dpi=600)
    if show:
        fig.show()
def _two_class_mask(scores, class_to_plot):
    """Turn per-pixel class scores into a two-level mask.

    The class of each pixel is ``argmax(scores, axis=-1) + 1``. Pixels of
    class ``class_to_plot[0]`` get 0.5, pixels of class ``class_to_plot[1]``
    get 1.0, and all others get 0.
    """
    labels = np.argmax(scores, axis=-1) + 1
    return (np.where(labels == class_to_plot[0], 0.5, 0.0)
            + np.where(labels == class_to_plot[1], 1.0, 0.0))


def plot_segmask_3cl_input(y_pred, x_in, y_true=None, class_to_plot=(2, 3),
                           input_to_plot=1, input_name='channel 1',
                           xtick_int=50, ytick_int=50, show_plt=True,
                           save_imag=True, imag_name='pred_mask_input',
                           save_as='pdf'):
    '''Plot a 3-class segmentation mask next to one input channel.

    This is a prototype; it could be merged with ``plot_segmask_input``.

    Parameters
    ----------
    y_pred : array-like
        Predicted per-pixel class scores, with classes on the last axis.
    x_in : array-like
        Input image, with channels on the last axis.
    y_true : array-like, optional
        Ground-truth class scores with the same layout as ``y_pred``. When
        given, a "Ground truth" panel is drawn between the two others.
    class_to_plot : pair of int
        The two (1-based) classes to highlight, drawn at values 0.5 and 1.0.
    input_to_plot : int
        1-based index of the input channel to show.
    input_name : str
        Title of the input panel.
    xtick_int, ytick_int : int
        Tick spacing in pixels.
    show_plt : bool
        Display the figure with ``plt.show()``. Otherwise the figure is closed.
    save_imag : bool
        Save the figure to ``imag_name + '.' + save_as``.
    imag_name, save_as : str
        Output file stem and extension.
    '''
    pred_mask = _two_class_mask(y_pred, class_to_plot)
    channel = x_in[..., input_to_plot - 1]
    if y_true is not None:
        # Prediction mask, ground truth and input image (1 channel)
        grid_imag = [pred_mask, _two_class_mask(y_true, class_to_plot), channel]
        grid_cmap = ['jet', 'gray', 'gray']
        grid_title = [r'Segmentation mask', r'Ground truth', input_name]
        fig_width = 10.5
    else:
        # Prediction mask and input image (1 channel)
        grid_imag = [pred_mask, channel]
        grid_cmap = ['jet', 'gray']
        grid_title = [r'Segmentation mask', input_name]
        fig_width = 6.8
    n_cols = len(grid_imag)

    fig = plt.figure(figsize=(fig_width, 4))
    grid = ImageGrid(fig, rect=[0.1, 0.07, 0.85, 0.9],
                     nrows_ncols=(1, n_cols), axes_pad=0.25, share_all=True)
    xticks = np.arange(0, pred_mask.shape[1] + 1, xtick_int)
    yticks = np.arange(0, pred_mask.shape[0] + 1, ytick_int)
    for i, (image, cmap, title) in enumerate(zip(grid_imag, grid_cmap,
                                                 grid_title)):
        ax = grid[i]
        ax.imshow(image, vmin=0, vmax=1, cmap=cmap)
        ax.set_xticks(xticks)
        ax.set_yticks(yticks)
        ax.set_xlabel(r'image width [pixel]')
        ax.set_ylabel(r'image height [pixel]')
        ax.set_title(title)

    # Save BEFORE showing. With most backends plt.show() blocks and then
    # releases the figure, so a savefig after it writes an empty image.
    if save_imag:
        fig.savefig(imag_name + '.' + save_as, bbox_inches='tight')
    if show_plt:
        plt.show()
    else:
        # Free the figure's memory, since it is never displayed.
        plt.close(fig)
def _add_single_colorbar(grid, cmap, vrange, label):
    """Draw a colorbar for ``cmap`` over ``vrange`` into ``grid.cbar_axes[0]``.

    The panels are drawn by ``standardfig``, so a bare ScalarMappable with
    the same cmap/normalisation stands in for them.
    """
    norm = mpl.colors.Normalize(vmin=vrange[0], vmax=vrange[1])
    mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    cbr = plt.colorbar(mappable, cax=grid.cbar_axes[0])
    cbr.ax.set_ylabel(label, labelpad=-1)


def _plot_masked_moment(dqube, mqube, mask, moment, cmap, pdfname,
                        plot_kwargs):
    """Plot the masked data and model moment maps side by side and save them.

    Both cubes are restricted with ``mask_region(mask=mask)`` before the
    moment ``moment`` is computed. The colour range is the NaN-ignoring
    min/max of the data moment map. ``plot_kwargs`` is forwarded to
    ``standardfig``.
    """
    dMom = dqube.mask_region(mask=mask).calculate_moment(moment=moment)
    mMom = mqube.mask_region(mask=mask).calculate_moment(moment=moment)

    # One row of two panels, so use a landscape figure. Moment-2 previously
    # used a portrait (5, 8) figure here, unlike moment-1.
    fig = plt.figure(1, (8., 5.))
    grid = ImageGrid(fig, 111, nrows_ncols=(1, 2), axes_pad=0.,
                     cbar_mode='single', cbar_location='right')
    vrange = [np.nanmin(dMom.data), np.nanmax(dMom.data)]
    standardfig(raster=dMom, ax=grid[0], fig=fig, cmap=cmap, vrange=vrange,
                text='Data', textprop=[dict(size=12)], **plot_kwargs)
    standardfig(raster=mMom, ax=grid[1], fig=fig, cmap=cmap, vrange=vrange,
                beam=False, text='Model', textprop=[dict(size=12)],
                **plot_kwargs)
    _add_single_colorbar(grid, cmap, vrange, 'Moment-%d' % moment)
    plt.savefig(pdfname, format='pdf', dpi=300)
    plt.close()


def diagnostic_plots(model, chainfile, burnin=0.3, channelmaps=True,
                     cornerplot=True, momentmaps=True, bestarray=False,
                     outname='model', vrangecm=None, vrangemom=None,
                     cmap0='RdYlBu_r', cmap1='Spectral_r', cmap2='copper_r',
                     maskvalue=3.0, **kwargs):
    """Create diagnostic plots for a model and its MCMC chain.

    The plots are not paper-ready. They are meant to show how well the MCMC
    chain ran. Depending on the flags, these PDFs are written (all file
    names start with ``outname``):

    * ``_cornerplot.pdf``: corner plot of the post-burn-in chain (``corner``
      package). Each parameter is multiplied by its ``Conversion`` factor,
      and ``I0`` by an additional 1e3.
    * ``_datachannelmap.pdf``, ``_modelchannelmap.pdf``,
      ``_residualchannelmap.pdf``: channel maps of the data, model and
      residual. ``_combinedchannelmap.pdf``: the data with residual contours.
    * ``_moment0.pdf``: moment-0 maps of the data, the model, the residual,
      and the data with residual contours.
    * ``_moment1.pdf``, ``_moment2.pdf``: masked moment-1 and moment-2 maps
      of the data and the model.

    Parameters
    ----------
    model : Qube
        Object that contains a defined model.
    chainfile : str
        File containing an MCMC chain saved with ``numpy.save``. Its shape
        is ``(nwalkers, nruns, dim + 1)``, where ``dim`` is the number of
        variable parameters. The extra slice holds the ln-probability of the
        parameters in each link.
    burnin : float
        Burn-in fraction of each walker's links (default 0.3, i.e. 30%).
    channelmaps : bool
        Plot the channel maps of the data, model and residual.
    cornerplot : bool
        Plot the corner plot of the chain.
    momentmaps : bool
        Plot the zeroth, first and second moments of the data and the model.
    bestarray : bool
        If True, the model is regenerated from the highest-probability
        parameters (``chainpar['BestArray']``). Otherwise it uses the median
        parameters (``chainpar['MedianArray']``).
    outname : str
        Root of the saved PDF file names.
    vrangecm : sequence of 2 floats, optional
        Colour range of the channel maps. Defaults to the NaN-ignoring
        min/max of the data.
    vrangemom : sequence of 2 floats, optional
        Colour range of the moment-0 maps. Defaults to
        ``[-3, 11] * Mom0sig``, the moment-0 sigma computed from
        ``model.variance``.
    cmap0, cmap1, cmap2 : str
        Colormaps for the moment-0, moment-1 and moment-2 maps.
    maskvalue : float
        Threshold, in units of ``Mom0sig``, used to build the mask from the
        data moment-0 map. The mask is applied in the higher-moment images.
        Default 3.0.
    **kwargs
        Forwarded to ``create_channelmap`` and ``standardfig``.
    """
    # Read the chain, drop the burn-in links and the trailing ln-probability
    # column, then flatten walkers x links into one list of samples.
    Chain = np.load(chainfile)
    Chain = Chain[:, int(burnin * Chain.shape[1]):, :-1]
    Chain = Chain.reshape((-1, Chain.shape[2]))

    # Load the chain into the model and regenerate the model with the wanted
    # values (either median or best).
    model.get_chainresults(chainfile, burnin=burnin)
    if bestarray:
        model.update_parameters(model.chainpar['BestArray'])
    else:
        model.update_parameters(model.chainpar['MedianArray'])
    model.create_model()

    # Create the data, model and residual cubes
    dqube = dc(model)
    mqube = dc(model)
    mqube.data = model.model
    rqube = dc(model)
    rqube.data = model.data - model.model

    if cornerplot:
        # Convert each parameter to a physically meaningful quantity.
        for idx, key in enumerate(model.mcmcmap):
            conversion = model.initpar[key]['Conversion']
            conversion = 1. if conversion is None else conversion.value
            # quick fix for I0: extra factor of 1e3
            if key == 'I0':
                Chain[:, idx] = Chain[:, idx] * 1E3
            Chain[:, idx] = Chain[:, idx] * conversion
        corner.corner(Chain, labels=model.mcmcmap,
                      quantiles=[0.16, 0.5, 0.84], show_titles=True)
        plt.savefig(outname + '_cornerplot.pdf', format='pdf', dpi=300)
        plt.close()

    if channelmaps:
        # Global properties shared by all channel maps
        if vrangecm is None:
            vrangecm = [np.nanmin(dqube.data), np.nanmax(dqube.data)]
        sigma = np.sqrt(model.variance[:, 0, 0])
        # Per-channel contour levels: -3, 3, 6, ..., 27 times that channel's sigma.
        clevels = [np.insert(np.arange(3, 30, 3), 0, -3) * s for s in sigma]

        create_channelmap(raster=dqube, contour=dqube, clevels=clevels,
                          pdfname=outname + '_datachannelmap.pdf',
                          vrange=vrangecm, **kwargs)
        create_channelmap(raster=mqube, contour=mqube, clevels=clevels,
                          pdfname=outname + '_modelchannelmap.pdf',
                          vrange=vrangecm, **kwargs)
        create_channelmap(raster=rqube, contour=rqube, clevels=clevels,
                          pdfname=outname + '_residualchannelmap.pdf',
                          vrange=vrangecm, **kwargs)
        # Data raster with residual contours
        create_channelmap(raster=dqube, contour=rqube, clevels=clevels,
                          pdfname=outname + '_combinedchannelmap.pdf',
                          vrange=vrangecm, **kwargs)
        plt.close()

    if momentmaps:
        # Moment-0 images
        dMom0 = dqube.calculate_moment(moment=0)
        mMom0 = mqube.calculate_moment(moment=0)
        rMom0 = rqube.calculate_moment(moment=0)

        # Moment-0 sigma of the data, and contour levels in units of it
        Mom0sig = (np.sqrt(np.nansum(model.variance[:, 0, 0])) *
                   model.__get_velocitywidth__())
        clevels = np.insert(np.arange(3, 30, 3), 0, -3) * Mom0sig
        if vrangemom is None:
            vrangemom = [-3 * Mom0sig, 11 * Mom0sig]
        # The mask is used for the moment-1 and moment-2 maps below.
        mask = dMom0.mask_region(value=Mom0sig * maskvalue, applymask=False)

        fig = plt.figure(1, (8., 8.))
        grid = ImageGrid(fig, 111, nrows_ncols=(2, 2), axes_pad=0.,
                         cbar_mode='single', cbar_location='right')
        standardfig(raster=dMom0, contour=dMom0, clevels=clevels, ax=grid[0],
                    fig=fig, vrange=vrangemom, cmap=cmap0, text='Data',
                    textprop=[dict(size=12)], **kwargs)
        standardfig(raster=mMom0, contour=mMom0, clevels=clevels, ax=grid[1],
                    fig=fig, vrange=vrangemom, cmap=cmap0, beam=False,
                    text='Model', textprop=[dict(size=12)], **kwargs)
        standardfig(raster=rMom0, contour=rMom0, clevels=clevels, ax=grid[2],
                    fig=fig, vrange=vrangemom, cmap=cmap0, text='Residual',
                    textprop=[dict(size=12)], **kwargs)
        standardfig(raster=dMom0, contour=rMom0, clevels=clevels, ax=grid[3],
                    fig=fig, vrange=vrangemom, cmap=cmap0,
                    text='Data with residual contours',
                    textprop=[dict(size=12)], **kwargs)
        _add_single_colorbar(grid, cmap0, vrangemom, 'Moment-0')
        plt.savefig(outname + '_moment0.pdf', format='pdf', dpi=300)
        plt.close()

        _plot_masked_moment(dqube, mqube, mask, 1, cmap1,
                            outname + '_moment1.pdf', kwargs)
        _plot_masked_moment(dqube, mqube, mask, 2, cmap2,
                            outname + '_moment2.pdf', kwargs)
def _savegif(
    self,
    stems: List[str],
    imgs: NDArray,
    masks: NDArray,
    reconstructed_imgs: NDArray,
    amaps: NDArray,
) -> None:
    """Save a 4-panel PNG per sample to ``results/`` and join them into ``result.gif``.

    The panels are:
    * the input image;
    * the reconstructed image;
    * the input with the ground-truth mask overlaid;
    * the input with the anomaly map overlaid (colour range [0, 1], with a
      single shared colorbar on the right).

    Frames are written as ``results/<stem>.png``, and the directory may
    already exist. Only the frames written by this call are put into the
    GIF, in input order. The GIF is built with ImageMagick's ``convert``
    command. If that command cannot be found, a warning is issued and no
    GIF is written.
    """
    import warnings

    os.makedirs("results", exist_ok=True)
    frame_paths = []
    rows = zip(stems, imgs, masks, reconstructed_imgs, amaps)
    for stem, img, mask, reconstructed_img, amap in tqdm(
            rows, total=len(stems), desc="savegif"):
        # How to get two subplots to share the same y-axis with a single colorbar
        # https://stackoverflow.com/a/38940369
        fig = plt.figure(figsize=(16, 4))
        grid = ImageGrid(
            fig=fig,
            rect=111,
            nrows_ncols=(1, 4),
            axes_pad=0.15,
            share_all=True,
            cbar_location="right",
            cbar_mode="single",
            cbar_size="5%",
            cbar_pad=0.15,
        )

        grid[0].imshow(img, cmap="gray")
        grid[0].set_title("Input Image", fontsize=20)

        grid[1].imshow(reconstructed_img, cmap="gray")
        grid[1].set_title("Reconstructed Image", fontsize=20)

        grid[2].imshow(img, cmap="gray")
        grid[2].imshow(mask, alpha=0.3, cmap="Reds")
        grid[2].set_title("Ground Truth", fontsize=20)

        grid[3].imshow(img, cmap="gray")
        im = grid[3].imshow(amap, alpha=0.3, cmap="jet", vmin=0, vmax=1)
        grid[3].cax.toggle_label(True)
        grid[3].set_title("Anomaly Map", fontsize=20)

        # Hide ticks and tick labels on every image panel.
        for ax in grid:
            ax.tick_params(labelbottom=False, labelleft=False,
                           bottom=False, left=False)
        plt.colorbar(im, cax=grid.cbar_axes[0])

        frame_path = f"results/{stem}.png"
        fig.savefig(frame_path, bbox_inches="tight")
        plt.close(fig)
        frame_paths.append(frame_path)

    # NOTE(inoue): The gif files converted by PIL or imageio were low-quality.
    # So, I used the conversion command (ImageMagick) instead.
    # Pass this run's frames explicitly, without a shell. The old glob
    # "results/*.png" also picked up stale frames from earlier runs and
    # ordered them alphabetically instead of in input order.
    if frame_paths:
        try:
            subprocess.run(
                ["convert", "-delay", "100", "-loop", "0",
                 *frame_paths, "result.gif"])
        except FileNotFoundError:
            warnings.warn(
                "ImageMagick 'convert' not found; result.gif was not created.")
# Reconstruct ``samp`` with the model and show the original and the
# reconstruction as 28x28 images.
with torch.no_grad():
    reco = hm.reconstruct(samp)
    reco_im = torch.squeeze(reco).reshape(28, 28)
    samp_im = torch.squeeze(samp).reshape(28, 28)
    plt.imshow(samp_im)
    plt.show()
    plt.imshow(reco_im)
    plt.show()

# <codecell>

# Draw n_side * n_side samples from the model and show them on a square grid.
# The sample count is derived from the grid side so that the two stay in sync.
n_side = 5
n_samples = n_side * n_side
with torch.no_grad():
    samp = hm.sample(n_samples)
    samp_im = torch.squeeze(samp).reshape(n_samples, 28, 28)
    fig = plt.figure(figsize=(10, 10))
    grid = ImageGrid(
        fig,
        111,  # similar to subplot(111)
        nrows_ncols=(n_side, n_side),  # n_side x n_side grid of axes
        axes_pad=0.1,  # pad between axes in inches
    )
    for ax, im in zip(grid, samp_im):
        ax.imshow(im)
    fig.suptitle('Sample faces drawn from HM')
    plt.show()
# TODO: debug same image problem
def make_gallery(files=None, ims=None, heads=None, outname=None, pfovs=None, scale="lin", norm=None, absmin=None, absmax=None, permax=None, permin=None, papercol=2, ncols=4, pwidth=None, nrows=None, view_as=None, view_px=None, inv=True, cmap='gist_heat', colorbar=None, cbarwidth=0.03, cbarlabel=None, cbarlabelpad=None, xlabel=None, ylabel=None, xlabelpad=None, ylabelpad=None, titles=None, titcols='black', titx=0.5, tity=0.95, titvalign='top', tithalign='center', subtitles=None, subtitcols='black', subtitx=0.5, subtity=0.05, subtitvalign='bottom', subtithalign='center', plotsym=None, contours=None, alphamap=None, cbartickinterval=None, sbarlength=None, sbarunit='as', sbarcol='black', sbarthick='2', sbarpos=[0.1, 0.05], majtickinterval=None, verbose=False, latex=True, texts=None, textcols='black', textx=0.05, texty=0.95, textvalign='top', texthalign='left', textsize=None, replace_NaNs=None, smooth=None, lines=None, latexfontsize=16, axes_pad=0.05): """ MISSING: - implementation to rotate images to North up (and East to the left if desired and necessary The purpose of this routine is to create puplication-quality multiplots of images with a maximum number of customisability. The following parameters can be set either for all images a single variable or as an array with the same number of elements to set the parameters individually INPUT: - flles : list of fits files to be plotted - pfovs : pixel size in arcsec for the images. 
If not provided then it will be looked for in the fits headers - log : enable logarithmic scaling - norm: provide the index of the image that should be normalised to - permax: set a percentile value which should be used as the maximum in the colormap - papercol: either 1 for a plot fitting into one column in the paper or 2 in case the plot is supposed to go over the full page - ncols: number of columns of images in the multiplot - nrows: number of rows of images in the multiplot - view_as: size of the field of view that should be plotted in arcsec - inv: if set true then the used colormap is inverted - cmap: colormap to be used - xlabelpad, xlabelpad : adjust the position of the x,y axis label - titles: provide titles for the individual images - titcols: colors for the titles - titx, tity: x,y position of the title - titvalign, tithalign: vertical and horizontal alignment of the title text with respect to the given position - subtitles: similar to titles but for another text in the image """ # --- read in the images in a 2D list if ims is None: if type(files) == list: n = len(files) if nrows is None and ncols is not None: nrows = int(np.ceil(1.0 * n / ncols)) if ncols is None and nrows is not None: ncols = int(np.ceil(1.0 * n / nrows)) # --- reshape file list into 2D array rest = n % nrows for i in range((nrows - rest) % nrows): files.append(None) files = np.reshape(files, (nrows, ncols)) elif type(files) == np.ndarray: s = np.shape(files) if len(s) > 1: if nrows is None: nrows = s[0] if ncols is None: ncols = s[1] n = nrows * ncols else: n = 1 nrows = 1 ncols = 1 files = np.reshape(files, (-1, ncols)) # --- create a fill a 2D image array # ims = [[None]*ncols]*nrows # ims = [None]*nrows*ncols ims = [[None] * ncols for _ in range(nrows)] #print(np.shape(ims)) for r in range(nrows): for c in range(ncols): if files[r][c] != None: if files[r][c] != "": i = ncols * r + c ims[r][c] = fits.getdata(files[r][c], ext=0) # images are provides else: s = np.shape(ims) if 
type(ims) == list or (type(ims) == np.ndarray and len(s) < 4): n = len(ims) if nrows is None and ncols is not None: nrows = int(np.ceil(1.0 * n / ncols)) if ncols is None and nrows is not None: ncols = int(np.ceil(1.0 * n / nrows)) # --- reshape file list into 2D array rest = n % nrows for i in range((nrows - rest) % nrows): ims.append(None) ims_old = np.copy(ims) ims = [[None] * ncols for _ in range(nrows)] for r in range(nrows): for c in range(ncols): i = ncols * r + c ims[r][c] = ims_old[i] elif type(ims) == np.ndarray: if len(s) > 1: if nrows is None: nrows = s[0] if ncols is None: ncols = s[1] n = nrows * ncols else: n = 1 nrows = 1 ncols = 1 ims = [[ims]] if heads is None and files is not None: heads = [[None] * ncols for _ in range(nrows)] for r in range(nrows): for c in range(ncols): if files[r][c] != None: if files[r][c] != "": heads[r][c] = fits.getheader(files[r][c], ext=0) # print(np.shape(heads), len(np.shape(heads))) if heads is not None: if len(np.shape(heads)) == 1: heads = np.full((nrows, ncols), heads, dtype=object) if pfovs is None and heads is not None: pfovs = [[None] * ncols for _ in range(nrows)] for r in range(nrows): for c in range(ncols): if heads[r][c] != None: if heads[r][c] != "": if "PFOV" in heads[r][c]: pfovs[r][c] = float(heads[r][c]['PFOV']) elif "HIERARCH ESO INS PFOV" in heads[r][c]: pfovs[r][c] = float( heads[r][c]["HIERARCH ESO INS PFOV"]) elif "CDELT1" in heads[r][c]: pfovs[r][c] = np.abs(heads[r][c]["CDELT1"]) * 3600 if len(pfovs) == 0: pfovs = None else: pfovs = reshape_input_param(pfovs, nrows, ncols) scale = reshape_input_param(scale, nrows, ncols) permin = reshape_input_param(permin, nrows, ncols) permax = reshape_input_param(permax, nrows, ncols) absmin = reshape_input_param(absmin, nrows, ncols) absmax = reshape_input_param(absmax, nrows, ncols) norm = reshape_input_param(norm, nrows, ncols) smooth = reshape_input_param(smooth, nrows, ncols) titles = reshape_input_param(titles, nrows, ncols) titcols = 
reshape_input_param(titcols, nrows, ncols) texts = reshape_input_param(texts, nrows, ncols) textcols = reshape_input_param(textcols, nrows, ncols) subtitles = reshape_input_param(subtitles, nrows, ncols) subtitcols = reshape_input_param(subtitcols, nrows, ncols) sbarlength = reshape_input_param(sbarlength, nrows, ncols) if lines is not None: lines = reshape_input_param(lines, nrows, ncols) if verbose: # ny, nx = ims[0].shape print("MAKE_GALLERY: nrows, ncols: ", nrows, ncols) print("MAKE_GALLERY: n: ", n) print("MAKE_GALLERY: pfovs: ", pfovs) print("MAKE_GALLERY: scales: ", scale) # print("MAKE_GALLERY: nx, ny: ", nx, ny) # --- set up the plotting configuration to use latex if latex: mpl.rcdefaults() mpl.rc( 'font', **{ 'family': 'sans-serif', 'serif': ['Computer Modern Serif'], 'sans-serif': ['Helvetica'], 'size': latexfontsize, 'weight': 500, 'variant': 'normal' }) mpl.rc('axes', **{'labelweight': 'normal', 'linewidth': 1.5}) mpl.rc('ytick', **{'major.pad': 8, 'color': 'k'}) mpl.rc('xtick', **{'major.pad': 8}) mpl.rc( 'mathtext', **{ 'default': 'regular', 'fontset': 'cm', 'bf': 'monospace:bold' }) mpl.rc('text', **{'usetex': True}) mpl.rc('text.latex',preamble=r'\usepackage{cmbright},\usepackage{relsize},'+\ r'\usepackage{upgreek}, \usepackage{amsmath}'+\ r'\usepackage{bm}') mpl.rc('contour', **{'negative_linestyle': 'solid'}) # dashed | solid plt.clf() # 14.17 # 6.93 if pwidth == None: if papercol == 1: pwidth = 6.93 elif papercol == 2: pwidth = 14.17 else: pwidth = 7 * papercol subpwidth = pwidth / ncols if verbose: print("MAKE_GALLERY: pwidth, subpwidth: ", pwidth, subpwidth) fig = plt.figure(figsize=(pwidth, nrows * subpwidth)) # fig.subplots_adjust(bottom=0.2) # fig.subplots_adjust(left=0.2) if inv: cmap = cmap + '_r' grid = ImageGrid(fig, 111, nrows_ncols=(nrows, ncols), axes_pad=axes_pad, aspect=True) handles = [] ahandles = [] vmin = np.zeros((nrows, ncols)) vmax = np.zeros((nrows, ncols)) amin = np.zeros((nrows, ncols)) amax = np.zeros((nrows, ncols)) # 
--- main plotting loop for r in range(nrows): for c in range(ncols): i = ncols * r + c if verbose: print("MAKE_GALLERY: r,c,i", r, c, i) print(" - smooth: ", smooth[r][c]) if np.shape(ims[r][c]) == (): continue if view_as is not None: view_px = view_as / pfovs[r][c] im = _crop_image(ims[r][c], box=np.ceil(view_px / 2.0) * 2 + 2) else: im = np.copy(ims[r][c]) #print(r,c, np.nanmax(ims[r][c]), np.nanmax(im)) if smooth[r][c]: im = gaussian_filter(im, sigma=smooth[r][c], mode='nearest') if norm[r][c]: im[0, 0] = np.nanmax(ims[int(norm[r][c])]) # --- adjust the cut levels if permin[r][c]: vmin[r][c] = np.nanpercentile(im, float(permin[r][c])) else: vmin[r][c] = np.nanmin(im) if permax[r][c]: vmax[r][c] = np.nanpercentile(im, float(permax[r][c])) else: vmax[r][c] = np.nanmax(im) if absmax[r][c] is not None: vmax[r][c] = absmax[r][c] if absmin[r][c] is not None: vmin[r][c] = absmin[r][c] sim = np.copy(im) sim[im < vmin[r][c]] = vmin[r][c] sim[im > vmax[r][c]] = vmax[r][c] if replace_NaNs is not None: idn = np.nonzero(np.isnan(sim)) # print(i, len(idn), idn[0]) if len(idn) == 0: continue elif replace_NaNs == "min": sim[np.ix_(idn[0], idn[1])] = vmin[r][c] elif replace_NaNs == "max": sim[np.ix_(idn[0], idn[1])] = vmax[r][c] else: sim[np.ix_(idn[0], idn[1])] = replace_NaNs # --- logarithmic scaling? 
if scale[r][c] == "log": sim = np.log10(1000.0 * (sim - vmin[r][c]) / (vmax[r][c] - vmin[r][c]) + 1) if verbose: print("MAKE_GALLERY: scale[r][c]", scale[r][c]) print("MAKE_GALLERY: vmin[r][c], vmax[r][c]", vmin[r][c], vmax[r][c]) handle = grid[i].imshow(sim, cmap=cmap, origin='lower', interpolation='nearest') ny, nx = im.shape if verbose: print("ny, nx:", ny, nx) if pfovs is not None: xmin = (nx / 2 - 1) * pfovs[r][c] xmax = -nx / 2 * pfovs[r][c] ymin = -ny / 2 * pfovs[r][c] ymax = (ny / 2 - 1) * pfovs[r][c] if xlabel is None: xlabel = 'RA offset ["]' if ylabel is None: ylabel = 'DEC offset ["]' else: xmax = (nx / 2 - 1) xmin = -nx / 2 ymin = -ny / 2 ymax = (ny / 2 - 1) if xlabel is None: xlabel = 'x offset [px]' if ylabel is None: ylabel = 'y offset [px]' # print(pfovs[r][c]) # pdb.set_trace() # set the extent of the image extent = [xmin, xmax, ymin, ymax] if verbose: print("extent:", extent) print("x/ylabel:", xlabel, ylabel) handle.set_extent(extent) # --- optionally overplot a second map using transparency if alphamap is not None: uim = np.copy(alphamap[r][c]['im']) # --- determine the levels if alphamap[r][c]['min'] is None: amin[r][c] = np.nanmin(im) elif str(alphamap[r][c]['min']) == 'vmax': amin[r][c] = vmax[r][c] elif str(alphamap[r][c]['min']) == 'vmin': amin[r][c] = vmin[r][c] else: amin[r][c] = np.nanpercentile(im, float(alphamap[r][c]['min'])) if alphamap[r][c]['max'] is None: amax[r][c] = np.nanmax(im) elif str(alphamap[r][c]['max']) == 'vmax': amax[r][c] = vmax[r][c] elif str(alphamap[r][c]['max']) == 'vmax': amax[r][c] = vmin[r][c] else: amax[r][c] = np.nanpercentile(im, float(alphamap[r][c]['max'])) uim[uim < amin[r][c]] = amin[r][c] uim[uim > amax[r][c]] = amax[r][c] cmap2 = _create_alpha_colmap(alphamap[r][c]['mincolor'], alphamap[r][c]['maxcolor'], alphamap[r][c]['minalpha'], alphamap[r][c]['maxalpha']) if alphamap[r][c]['log']: unorm = LogNorm() # uim = np.log10(1000.0 * (uim - amin[r][c]) / # (amax[r][c] - amin[r][c]) + 1) else: unorm = None 
ahandle = grid[i].imshow(uim, cmap=cmap2, origin='lower', interpolation='nearest', norm=unorm) ahandle.set_extent(extent) ahandles.append(ahandle) # --- optionally draw contours if contours is not None: # --- determine the levels if contours[r][c]['min'] is None: cmin = np.nanmin(im) elif str(contours[r][c]['min']) == 'vmax': cmin = vmax[r][c] elif str(contours[r][c]['min']) == 'vmin': cmin = vmin[r][c] else: cmin = contours[r][c]['min'] if contours[r][c]['max'] is None: cmax = np.nanmax(im) elif str(contours[r][c]['max']) == 'vmax': cmax = vmax[r][c] elif str(contours[r][c]['max']) == 'vmax': cmax = vmin[r][c] else: cmax = contours[r][c]['max'] if contours[r][c]['nstep'] is None: nstep = 10 else: nstep = contours[r][c]['nstep'] if contours[r][c]['stepsize'] is None: stepsize = int(cmax - cmin / nstep) else: stepsize = contours[r][c]['stepsize'] levels = np.arange(cmin, cmax, stepsize) cont = grid[i].contour(im, levels=levels, origin='lower', colors=contours[r][c]['color'], linewidth=contours[r][c]['linewidth'], extent=extent) grid[i].clabel(cont, levels[1::contours[r][c]['labelinterval']], inline=1, fontsize=contours[r][c]['labelsize'], fmt=contours[r][c]['labelfmt']) if pfovs is not None and view_as is not None: xmin = 0.5 * view_as xmax = -0.5 * view_as ymin = -0.5 * view_as ymax = 0.5 * view_as elif view_px is not None: xmin = -0.5 * view_px xmax = 0.5 * view_px ymin = -0.5 * view_px ymax = 0.5 * view_px if verbose: print("xmin,xmax,ymin,ymax:", xmin, xmax, ymin, ymax) grid[i].set_ylim(ymin, ymax) grid[i].set_xlim(xmin, xmax) # --- optionally draw some lines if lines is not None: nlin = len(lines[r][c]["length"]) for l in range(nlin): # --- provided unit for lines is arcsec? 
# print(lines[r][c]["length"][l]) if lines[r][c]["unit"][l] == "px": lines[r][c]["length"][l] *= pfovs[r][c] lines[r][c]["xoff"][l] *= pfovs[r][c] lines[r][c]["yoff"][l] *= pfovs[r][c] # print(lines[r][c]["length"][l]) ang_rad = (90 - lines[r][c]["pa"][l]) * np.pi / 180.0 hlength = 0.5 * lines[r][c]["length"][l] lx = [ lines[r][c]["xoff"][l] - hlength * np.cos(ang_rad), lines[r][c]["xoff"][l] + hlength * np.cos(ang_rad) ] ly = [ lines[r][c]["yoff"][l] - hlength * np.sin(ang_rad), lines[r][c]["yoff"][l] + hlength * np.sin(ang_rad) ] grid[i].plot(lx, ly, color=lines[r][c]["color"][l], linestyle=lines[r][c]["style"][l], marker='', linewidth=lines[r][c]["thick"][l], alpha=lines[r][c]["alpha"][l]) # theta1 = lines[r][c]["pa"][l] - 10 # theta2 = lines[r][c]["pa"][l] + 10 # --- angular error bars if lines[r][c]["pa_err"][l] > 0: theta1 = 90 - lines[r][c]["pa"][l] - lines[r][c][ "pa_err"][l] theta2 = 90 - lines[r][c]["pa"][l] + lines[r][c][ "pa_err"][l] parc = Arc( (lines[r][c]["xoff"][l], lines[r][c]["yoff"][l]), width=2 * hlength, height=2 * hlength, color=lines[r][c]["color"][l], linewidth=lines[r][c]["thick"][l], angle=0, theta1=theta1, theta2=theta2, alpha=lines[r][c]["alpha"][l]) grid[i].add_patch(parc) theta1 = -90 - lines[r][c]["pa"][l] - lines[r][c][ "pa_err"][l] theta2 = -90 - lines[r][c]["pa"][l] + lines[r][c][ "pa_err"][l] parc = Arc( (lines[r][c]["xoff"][l], lines[r][c]["yoff"][l]), width=2 * hlength, height=2 * hlength, color=lines[r][c]["color"][l], linewidth=lines[r][c]["thick"][l], angle=0, theta1=theta1, theta2=theta2, alpha=lines[r][c]["alpha"][l]) grid[i].add_patch(parc) # --- optionally draw some symbols if plotsym is not None: # if hasattr(plotsym[r][c]['x'], "__len__"): nsymplots = len(plotsym[r][c]['x']) print('nsymplots: ', nsymplots) for j in range(nsymplots): # grid[i].plot(plotsym[r][c]['x'][j], plotsym[r][c]['y'][j], # marker=plotsym[r][c]['marker'][j], # markersize=plotsym[r][c]['markersize'][j], # 
markeredgewidth=plotsym[r][c]['markeredgewidth'][j], # fillstyle=plotsym[r][c]['fillstyle'][j], # alpha=plotsym[r][c]['alpha'][j], # markerfacecolor=plotsym[r][c]['markerfacecolor'][j], # markeredgecolor=plotsym[r][c]['markeredgecolor'][j]) if plotsym[r][c]['marker'][j] == 'ch': marker = _crosshair_marker() elif plotsym[r][c]['marker'][j] == 'ch45': marker = _crosshair_marker(pa=45) else: marker = plotsym[r][c]['marker'][j] grid[i].scatter(plotsym[r][c]['x'][j], plotsym[r][c]['y'][j], marker=marker, s=plotsym[r][c]['size'][j], linewidths=plotsym[r][c]['linewidth'][j], alpha=plotsym[r][c]['alpha'][j], facecolor=plotsym[r][c]['color'][j], edgecolors=plotsym[r][c]['edgecolor'][j], linestyle=plotsym[r][c]['linestyle'][j]) if plotsym[r][c]['label'][j] is not None: grid[i].text( plotsym[r][c]['x'][j][0] + plotsym[r][c]['labelxoffset'][j], plotsym[r][c]['y'][j][0] + plotsym[r][c]['labelyoffset'][j], plotsym[r][c]['label'][j], color=plotsym[r][c]['labelcolor'][j], verticalalignment=plotsym[r][c]['labelvalign'][j], horizontalalignment=plotsym[r][c]['labelhalign'] [j]) # --- ticks if majtickinterval is None: majorLocator = MultipleLocator(1) if ymax - ymin < 2: majorLocator = MultipleLocator(0.5) minorLocator = AutoMinorLocator(5) elif (ymax - ymin > 10) & (ymax - ymin <= 20): majorLocator = MultipleLocator(2) elif (ymax - ymin > 20) & (ymax - ymin <= 100): majorLocator = MultipleLocator(5) elif ymax - ymin > 100: majorLocator = MultipleLocator(10) else: majorLocator = MultipleLocator(majtickinterval) minorLocator = AutoMinorLocator(10) if ymax - ymin < 2: minorLocator = AutoMinorLocator(5) grid[i].xaxis.set_minor_locator(minorLocator) grid[i].xaxis.set_major_locator(majorLocator) grid[i].yaxis.set_minor_locator(minorLocator) grid[i].yaxis.set_major_locator(majorLocator) grid[i].yaxis.set_tick_params(width=1.5, which='both') grid[i].xaxis.set_tick_params(width=1.5, which='both') grid[i].xaxis.set_tick_params(length=6) grid[i].yaxis.set_tick_params(length=6) 
grid[i].xaxis.set_tick_params(length=3, which='minor') grid[i].yaxis.set_tick_params(length=3, which='minor') # --- text if titles[r][c]: # pdb.set_trace() if verbose: print("titles[r][c]", titles[r][c]) grid[i].set_title(titles[r][c], x=titx, y=tity, color=titcols[r][c], verticalalignment=titvalign, horizontalalignment=tithalign) if subtitles[r][c]: if verbose: print("subtitles[r][c]", subtitles[r][c]) grid[i].text(subtitx, subtity, subtitles[r][c], color=subtitcols[r][c], transform=grid[i].transAxes, verticalalignment=subtitvalign, horizontalalignment=subtithalign) if texts[r][c]: if verbose: print("texts[r][c]", texts[r][c]) grid[i].text(textx, texty, texts[r][c], color=textcols[r][c], transform=grid[i].transAxes, verticalalignment=textvalign, horizontalalignment=texthalign, fontsize=textsize) # --- scale bar for size comparison if sbarlength[r][c]: if sbarunit == 'px': sbarlength[r][c] = sbarlength[r][c] * pfovs[r][c] sx = [ sbarpos[0] - 0.5 * sbarlength[r][c], sbarpos[0] + 0.5 * sbarlength[r][c] ] sy = [sbarpos[1], sbarpos[1]] grid[i].plot(sx, sy, linewidth=sbarthick, color=sbarcol, transform=grid[i].transAxes) handles.append(handle) # --- get the extent of the largest box containing all the axes/subplots extents = np.array([bla.get_position().extents for bla in grid]) bigextents = np.empty(4) bigextents[:2] = extents[:, :2].min(axis=0) bigextents[2:] = extents[:, 2:].max(axis=0) # --- distance between the external axis and the text if xlabelpad is None and papercol == 2: xlabelpad = 0.03 elif xlabelpad is None and papercol == 1: xlabelpad = 0.15 elif xlabelpad is None: xlabelpad = 0.15 / papercol * 0.5 if ylabelpad is None and papercol == 2: ylabelpad = 0.055 elif ylabelpad is None and papercol == 1: ylabelpad = 0.1 elif ylabelpad is None: ylabelpad = 0.1 / papercol if verbose: print("xlabelpad,ylabelpad: ", xlabelpad, ylabelpad) # --- text to mimic the x and y label. 
The text is positioned in # the middle fig.text((bigextents[2] + bigextents[0]) / 2, bigextents[1] - xlabelpad, xlabel, horizontalalignment='center', verticalalignment='bottom') fig.text(bigextents[0] - ylabelpad, (bigextents[3] + bigextents[1]) / 2, ylabel, rotation='vertical', horizontalalignment='left', verticalalignment='center') # --- now colorbar business: if colorbar is not None: # first draw the figure, such that the axes are positionned fig.canvas.draw() #create new axes according to coordinates of image plot trans = fig.transFigure.inverted() if colorbar == 'row': for i in range(nrows): g = grid[i * ncols - 1].bbox.transformed(trans) if alphamap[ncols, i] is not None: height = 0.5 * g.height pos = [ g.x1 + g.width * 0.02, g.y0 + height, g.width * cbarwidth, height ] cax = fig.add_axes(pos) if alphamap[ncols, i]['log']: formatter = LogFormatter(10, labelOnlyBase=False) cb = plt.colorbar(ticks=[ 0.1, 0.2, 0.5, 1, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000 ], format=formatter, mappable=ahandles[i * ncols - 1], cax=cax) else: cb = plt.colorbar(mappable=ahandles[i * ncols - 1], cax=cax) #majorLocator = MultipleLocator(alphamap[i*ncols - 1]['cbartickinterval']) #cb.ax.yaxis.set_major_locator(majorLocator) # if alphamap[i*ncols - 1]['log']: # oldlabels = cb.ax.get_yticklabels() # oldlabels = np.array([float(x.get_text().replace('$','')) for x in oldlabels]) # newlabels = 0.001*(10.0**oldlabels -1)*(amax[i*ncols - 1] - amin[i*ncols - 1]) + amin[i*ncols - 1] # newlabels = [str(x)[:4] for x in newlabels] # cb.ax.set_yticklabels(newlabels) else: height = g.height # --- pos = [left, bottom, width, height] pos = [g.x1 + g.width * 0.02, g.y0, g.width * cbarwidth, height] cax = fig.add_axes(pos) if (vmax[ncols, i] - vmin[ncols, i]) < 10: decimals = 1 else: decimals = 0 if (vmax[ncols, i] - vmin[ncols, i]) < 1: decimals = 2 ticks = np.arange(vmin[ncols, i] + 1.0 / 10**decimals, vmax[ncols, i], cbartickinterval) ticks = np.round(ticks, decimals=decimals) cb = 
plt.colorbar(ticks=ticks, mappable=handles[i * ncols - 1], cax=cax) # majorLocator = MultipleLocator(cbartickinterval) # cb.ax.yaxis.set_major_locator(majorLocator) if scale[ncols, i] == "log": oldlabels = cb.ax.get_yticklabels() oldlabels = np.array( [float(x.get_text().replace('$', '')) for x in oldlabels]) newlabels = 0.001 * (10.0**oldlabels - 1) * ( vmax[i * ncols - 1] - vmin[ncols, i]) + vmin[ncols, i] newlabels = [str(x)[:4] for x in newlabels] cb.ax.set_yticklabels(newlabels) elif colorbar == 'single': gb = grid[-1].bbox.transformed(trans) gt = grid[0].bbox.transformed(trans) if alphamap is not None: height = 0.5 * (gt.y1 - gb.y0) pos = [ gb.x1 + gb.width * 0.02, gb.y0 + height, gb.width * cbarwidth, height ] cax = fig.add_axes(pos) # cb = plt.colorbar(mappable=ahandles[0], cax=cax) # majorLocator = MultipleLocator(alphamap[0]['cbartickinterval']) # cb.ax.yaxis.set_major_locator(majorLocator) # if alphamap[0]['log']: # oldlabels = cb.ax.get_yticklabels() # oldlabels = np.array([float(x.get_text().replace('$','')) for x in oldlabels]) # newlabels = 0.001*(10.0**oldlabels -1)*(amax[0] - amin[0]) + amin[0] # newlabels = [str(x)[:4] for x in newlabels] # cb.ax.set_yticklabels(newlabels) if alphamap[0]['log']: formatter = LogFormatter(10, labelOnlyBase=False) cb = plt.colorbar(ticks=[ 0.1, 0.2, 0.5, 1, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000 ], format=formatter, mappable=ahandles[0], cax=cax) else: cb = plt.colorbar(mappable=ahandles[0], cax=cax) else: height = (gt.y1 - gb.y0) # --- pos = [left, bottom, width, height] pos = [gb.x1 + gb.width * 0.02, gb.y0, gb.width * cbarwidth, height] cax = fig.add_axes(pos) cb = plt.colorbar(mappable=handles[0], cax=cax) # majorLocator = MultipleLocator(cbartickinterval) # cb.ax.yaxis.set_major_locator(majorLocator) if scale[0, 0] == "log": oldlabels = cb.ax.get_yticklabels() oldlabels = np.array( [float(x.get_text().replace('$', '')) for x in oldlabels]) newlabels = 0.001 * (10.0**oldlabels - 1) * (vmax[0] - 
vmin[0]) + vmin[0] newlabels = [str(x)[:4] for x in newlabels] cb.ax.set_yticklabels(newlabels) if cbarlabel is not None: if cbarlabelpad is None: if papercol == 1: cbarlabelpad = 0.15 else: cbarlabelpad = 0.08 fig.text(bigextents[2] + cbarlabelpad, (bigextents[3] + bigextents[1]) / 2, cbarlabel, rotation='vertical', horizontalalignment='right', verticalalignment='center') if outname: plt.savefig(outname, bbox_inches='tight', pad_inches=0.1) plt.close(fig) else: plt.show() if latex: mpl.rcdefaults()
# Report how many training images each category directory holds.
for category in CATEGORIES:
    print('{}: {} images'.format(category, len(os.listdir(os.path.join(train_dir, category)))))
print("-" * 27)

# Build a dataframe with one row per training image: file path, category name
# and integer category id.  Paths are built from ``train_dir`` (the directory
# that was actually listed) rather than a hard-coded '../data/train/' prefix,
# so they stay valid wherever ``train_dir`` points.
train = []
for category_id, category in enumerate(CATEGORIES):
    for file_name in os.listdir(os.path.join(train_dir, category)):
        train.append([os.path.join(train_dir, category, file_name), category, category_id])
train = pd.DataFrame(train, columns=['file', 'category', 'category_id'])

# Show the sample images: one grid row per category, up to NumCatergories
# images per row.
fig = plt.figure(1, figsize=(NumCatergories, NumCatergories))
grid = ImageGrid(fig, 111, nrows_ncols=(NumCatergories, NumCatergories), axes_pad=0.05)
for category_id, category in enumerate(CATEGORIES):
    filepaths = train[train['category'] == category]['file'].values[:NumCatergories]
    for col, filepath in enumerate(filepaths):
        # Index the cell from (row, column) rather than a running counter, so a
        # category with fewer images cannot push later categories into the
        # wrong row.
        ax = grid[category_id * NumCatergories + col]
        ax.imshow(read_img(filepath))
        ax.axis('off')
        if col == NumCatergories - 1:
            # Label the row next to its last image; (250, 112) is in the data
            # coordinates of that image.
            ax.text(250, 112, category, verticalalignment='center')
print("display samples")
plt.pause(3)
plt.savefig('../results/sample_images.png', transparent=True)
def test_calc_likelihoods_2():
    """Check that ``cloudpy._calc_likelihoods`` recovers a known model.

    A 5x5 Av image and a 5-channel HI cube are built so that the inner
    channels (1-3) are correlated with Av.  The grid search must recover
    ``intercept = 0.9``, ``dgr = 0.5`` and ``width = 2``.  As a by-product
    the HI cube, the Av image, the N(HI) image and the best-fit Av model are
    saved as PNG files.
    """
    from numpy.testing import assert_almost_equal
    from myimage_analysis import calculate_nhi
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Scaling the test divides out both when handing the cube to
    # _calc_likelihoods and when converting calculate_nhi's output.  It must be
    # the same number in both places: the original used 1.832e-2 for the
    # former and 1.823e-2 for the latter (a transposed-digit typo).  Presumably
    # the standard 1.823e18 cm^-2 (K km/s)^-1 HI conversion in units of
    # 1e20 cm^-2 — TODO confirm against calculate_nhi.
    nhi_factor = 1.823e-2

    def _save_grid(images, path):
        """Draw *images* side by side in one ImageGrid row and save to *path*."""
        fig = plt.figure(figsize=(4, 4))
        imagegrid = ImageGrid(fig, (1, 1, 1),
                              nrows_ncols=(1, len(images)),
                              ngrids=len(images),
                              cbar_mode="single",
                              cbar_location='top',
                              cbar_pad="2%",
                              cbar_size='3%',
                              axes_pad=0.1,
                              aspect=True,
                              label_mode='L',
                              share_all=True)
        # pyplot.get_cmap still accepts a LUT size; matplotlib.cm.get_cmap was
        # removed in Matplotlib 3.9.
        cmap = plt.get_cmap('Greys', 5)
        for ax, image in zip(imagegrid, images):
            im = ax.imshow(image,
                           origin='lower',
                           cmap=cmap,
                           interpolation='none',
                           vmin=0,
                           vmax=4)
        imagegrid.cbar_axes[0].colorbar(im)
        plt.savefig(path)
        plt.close(fig)

    av_image = np.array([[0, 0, 0, 0, 0],
                         [0, 1, 1, 1, 0],
                         [0, 1, 2, 1, 0],
                         [np.nan, 1, 1, 1, 0],
                         [0, 0, 0, 0, 0]])
    av_image_error = 0.1 * np.ones(av_image.shape)

    # Channels 1-3 are correlated with Av; the edge channels 0 and 4 carry
    # uncorrelated values and a few NaNs.
    hi_cube = np.array([
        [
            [1., 0., 0., 0., 0.],
            [np.nan, 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [1., 0., 0., 0., 10.],
        ],
        [
            [0., 0., 0., 0., 0.],
            [0., 0., 2., 0., 0.],
            [0., 0., 4., 0., 0.],
            [0., 0., 2., 0., 0.],
            [0., 0., 0., 0., 0.],
        ],
        [
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 2., 0.],
            [0., 0., 0., 2., 0.],
            [0., 0., 0., 2., np.nan],
            [0., 0., 0., 0., 0.],
        ],
        [
            [0., 0., 0., 0., 0.],
            [0., 2., 0., 0., 0.],
            [0., 2., 0., 0., 0.],
            [0., 2., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
        ],
        [
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., np.nan],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [1., 0., 0., 0., 0.2],
        ],
    ], dtype=float)

    _save_grid([hi_cube[i, :, :] for i in range(5)],
               '/usr/users/ezbc/Desktop/hi_cube.png')

    hi_vel_axis = np.arange(0, 5, 1)

    # Add a constant offset the fit has to recover.
    intercept_answer = 0.9
    av_image = av_image + intercept_answer

    plt.rcParams.update({'figure.figsize': (1, 1)})
    _save_grid([av_image], '/usr/users/ezbc/Desktop/av.png')

    width_grid = np.arange(0, 5, 1)
    dgr_grid = np.arange(0, 1, 0.1)
    intercept_grid = np.arange(-1, 1, 0.1)
    vel_center = 2

    results = cloudpy._calc_likelihoods(
        hi_cube=hi_cube / nhi_factor,
        hi_vel_axis=hi_vel_axis,
        vel_center=vel_center,
        av_image=av_image,
        av_image_error=av_image_error,
        width_grid=width_grid,
        dgr_grid=dgr_grid,
        intercept_grid=intercept_grid,
    )

    dgr_answer = 1 / 2.0
    width_answer = 2
    width = results['width_max']
    dgr = results['dgr_max']
    intercept = results['intercept_max']
    print(width)

    if 0:
        # Debug switch: plot the expected model instead of the fitted one.
        width = width_answer
        intercept = intercept_answer
        dgr = dgr_answer

    vel_range = (vel_center - width / 2.0, vel_center + width / 2.0)
    nhi_image = calculate_nhi(cube=hi_cube,
                              velocity_axis=hi_vel_axis,
                              velocity_range=vel_range) / nhi_factor

    _save_grid([nhi_image], '/usr/users/ezbc/Desktop/nhi.png')
    _save_grid([nhi_image * dgr + intercept],
               '/usr/users/ezbc/Desktop/av_model.png')

    print('residuals = ')
    print(av_image - (nhi_image * dgr + intercept))
    print('dgr', dgr)
    print('intercept', intercept)
    print('width', width)

    assert_almost_equal(results['intercept_max'], intercept_answer)
    assert_almost_equal(results['dgr_max'], dgr_answer)
    assert_almost_equal(results['width_max'], width_answer)
def plot_tsne_selection_grid(z_pos, x_pos, z_neg, vmin, vmax, fig_path,
                             labels=None, fig_size=(9, 9), g_j=7, s=.5,
                             suffix='png'):
    """Draw one t-SNE scatter panel per column of *x_pos* and save the figure.

    Each panel shows the background points ``z_neg`` in light grey and
    overlays ``z_pos`` coloured by the matching column of ``x_pos``.  Every
    panel gets its own colour bar on top.

    Args:
        z_pos: points to colour; columns 0 and 1 are the x/y coordinates.
        x_pos: colour values for ``z_pos``; one panel per column.
        z_neg: background points; columns 0 and 1 are the x/y coordinates.
        vmin, vmax: per-panel colour limits, indexed by column of ``x_pos``.
        fig_path: output path without extension.
        labels: per-panel titles; defaults to the column indices as strings.
        fig_size: figure size in inches.
        g_j: number of panels per grid row.
        s: marker size.
        suffix: file extension, also passed as the ``savefig`` format.
    """
    ncol = x_pos.shape[1]
    # Ceiling division: just enough rows to hold every panel.
    g_i = -(-ncol // g_j)
    if labels is None:
        # NumPy has no ``np.range``; the builtin range yields 0..ncol-1.
        labels = [str(a) for a in range(ncol)]
    fig = plt.figure(figsize=fig_size)
    fig.clf()
    # ``add_all=True`` was dropped: it was already ImageGrid's default and the
    # keyword itself was removed in Matplotlib 3.5.
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(g_i, g_j),
        ngrids=ncol,
        aspect=True,
        direction="row",
        axes_pad=(0.15, 0.5),
        label_mode="1",
        share_all=True,
        cbar_location="top",
        cbar_mode="each",
        cbar_size="8%",
        cbar_pad="5%",
    )
    for seq_index in range(ncol):
        ax = grid[seq_index]
        ax.text(0, .92, labels[seq_index],
                horizontalalignment='center',
                transform=ax.transAxes, size=20, weight='bold')
        colour_values = x_pos[:, seq_index]
        ax.scatter(z_neg[:, 0], z_neg[:, 1], s=s, marker='o', c='lightgray',
                   alpha=0.5, edgecolors='face')
        im = ax.scatter(z_pos[:, 0], z_pos[:, 1], s=s, marker='o',
                        c=colour_values, cmap=cm.jet, edgecolors='face',
                        vmin=vmin[seq_index], vmax=vmax[seq_index])
        ax.cax.colorbar(im)
        clean_axis(ax)
        ax.grid(False)
    plt.savefig('.'.join([fig_path, suffix]), format=suffix)
    plt.close(fig)
def plot_mass2light(
        ages=None,
        m2l=None,
        limits=None,
        savedir='./',
        filename=None,
        show=True,
        title='',
):
    '''Plot mass-to-light ratio against age on a single ImageGrid panel.

    Parameters
    ----------
    ages : sequence
        x-values, labelled as age in Gyr.  Required despite the ``None``
        default: its length sizes the colour cycle.
    m2l : sequence
        Mass-to-light ratios, one per age.
    limits : sequence of 4 floats, optional
        ``(xmin, xmax, ymin, ymax)`` axis limits.
    savedir, filename : str
        When ``filename`` is given the figure is saved to
        ``savedir + filename`` at 600 dpi.
    show : bool
        Call ``fig.show()`` at the end.
    title : str
        Axes title.
    '''
    # Import external modules (the unused ``pyfits`` import was removed: it is
    # a deprecated package whose absence made this function fail to import).
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    colormap = plt.cm.gist_ncar
    color_cycle = [colormap(i) for i in np.linspace(0, 0.9, len(ages))]
    fontScale = 12
    params = {
        'axes.labelsize': fontScale,
        'axes.titlesize': fontScale,
        # 'font.size' replaces the 'text.fontsize' alias, which current
        # Matplotlib rejects as an unknown rcParam.
        'font.size': fontScale,
        'legend.fontsize': fontScale * 3 / 4,
        'xtick.labelsize': fontScale,
        'ytick.labelsize': fontScale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (6, 6),
        # 'axes.prop_cycle' replaces the removed 'axes.color_cycle'.
        'axes.prop_cycle': matplotlib.cycler(color=color_cycle),
    }
    plt.rcParams.update(params)

    # Create figure; 'all' is the valid spelling of the label mode ('All' is
    # silently ignored by old Matplotlib and rejected by new releases).
    fig = plt.figure()
    grid = ImageGrid(fig, (1, 1, 1),
                     nrows_ncols=(1, 1),
                     ngrids=1,
                     direction='row',
                     axes_pad=1,
                     aspect=False,
                     share_all=True,
                     label_mode='all')

    ax = grid[0]
    ax.plot(ages, m2l, marker='s', color='k', markersize=3)

    # Horizontal reference line at 4.83 / 4.64, annotated as M_sun / L_sun.
    ax.axhline(y=4.83 / 4.64, xmin=-1, xmax=100, color='k')
    ax.annotate(r'$M_\odot / L_\odot$',
                xy=(10, 4.83 / 4.64),
                textcoords='offset points',
                xytext=(2, 3))

    if limits is not None:
        ax.set_xlim(limits[0], limits[1])
        ax.set_ylim(limits[2], limits[3])

    # Adjust aesthetics
    ax.set_xlabel(r'Age (Gyr)', )
    ax.set_ylabel(r'$M / L (M_\odot / L_\odot$)')
    ax.grid(True)
    ax.legend(loc='upper right')
    ax.set_title(title)

    if filename is not None:
        plt.savefig(savedir + filename, bbox_inches='tight', dpi=600)
    if show:
        fig.show()
def plot_fluxes(wavelengths=None,
                flux_list=None,
                ages=None,
                metals=None,
                limits=None,
                savedir='./',
                filename=None,
                show=True,
                title='',
                log_scale=(1, 1),
                normalized=True,
                attenuations=None,
                age_unit='Gyr',
                balmer_line=False):
    '''Plot several spectra (flux against wavelength) on one ImageGrid panel.

    Each entry of ``flux_list`` is drawn against ``wavelengths``.  Its legend
    label is built from ``ages`` and/or ``metals``; if both are None, it is
    built from ``attenuations``.

    Parameters
    ----------
    wavelengths : sequence
        Shared x-values (axis labelled in Angstrom).
    flux_list : sequence of sequences
        Spectra to plot.  Required: its length sizes the colour cycle.
    ages, metals, attenuations : sequence, optional
        Per-spectrum values used for the legend labels.
    limits : sequence of 4 floats, optional
        ``(xmin, xmax, ymin, ymax)`` axis limits.
    savedir, filename : str
        When ``filename`` is given the figure is saved to
        ``savedir + filename`` at 600 dpi.
    show : bool
        Call ``fig.show()`` at the end.
    title : str
        Axes title.
    log_scale : pair of bool
        Use a log x-axis and/or log y-axis.
    normalized : bool
        Only selects the y-axis label; the data are plotted as given.
    age_unit : str
        Unit string used in age labels.
    balmer_line : bool
        Draw vertical lines at the listed Balmer-series wavelengths.
    '''
    # Import external modules (the unused ``pyfits`` import was removed: it is
    # a deprecated package whose absence made this function fail to import).
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    colormap = plt.cm.gist_ncar
    color_cycle = [colormap(i) for i in np.linspace(0, 0.9, len(flux_list))]
    fontScale = 12
    params = {
        'axes.labelsize': fontScale,
        'axes.titlesize': fontScale,
        # 'font.size' replaces the 'text.fontsize' alias, which current
        # Matplotlib rejects as an unknown rcParam.
        'font.size': fontScale,
        'legend.fontsize': fontScale * 3 / 4,
        'xtick.labelsize': fontScale,
        'ytick.labelsize': fontScale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (6, 6),
        # 'axes.prop_cycle' replaces the removed 'axes.color_cycle'.
        'axes.prop_cycle': matplotlib.cycler(color=color_cycle),
    }
    plt.rcParams.update(params)

    # Create figure; 'all' is the valid spelling of the label mode.
    fig = plt.figure()
    grid = ImageGrid(fig, (1, 1, 1),
                     nrows_ncols=(1, 1),
                     ngrids=1,
                     direction='row',
                     axes_pad=1,
                     aspect=False,
                     share_all=True,
                     label_mode='all')

    ax = grid[0]
    if balmer_line:
        for line in [6560, 4861, 4341, 4102, 3970, 3889, 3835, 3646]:
            ax.axvline(x=line, ymin=0, ymax=1e10, color='k', alpha=0.5)

    for i, fluxes in enumerate(flux_list):
        if ages is not None and metals is None:
            ax.plot(wavelengths, fluxes,
                    label='Age = %.1f %s' % (ages[i], age_unit))
        elif metals is not None and ages is None:
            ax.plot(wavelengths, fluxes, label=r'Z = %s ' % metals[i])
        elif ages is not None and metals is not None:
            ax.plot(wavelengths, fluxes,
                    label='Age = %.1f %s, Z = %s ' % (ages[i], age_unit, metals[i]))
        elif attenuations is not None:
            ax.plot(wavelengths, fluxes,
                    label=r'$A_V = $ %s' % attenuations[i])

    if limits is not None:
        ax.set_xlim(limits[0], limits[1])
        ax.set_ylim(limits[2], limits[3])

    # Adjust aesthetics
    if log_scale[0]:
        ax.set_xscale('log')
    if log_scale[1]:
        ax.set_yscale('log')
    ax.set_xlabel(r'$\lambda (\AA$)', )
    if normalized:
        ax.set_ylabel(r'$f_\lambda / f_{5500 \AA}$')
    else:
        ax.set_ylabel(r'$f_\lambda d\lambda$')
    ax.grid(True)
    ax.legend(loc='upper right')
    ax.set_title(title)

    if filename is not None:
        plt.savefig(savedir + filename, bbox_inches='tight', dpi=600)
    if show:
        fig.show()
def plot_mags(
        wavelengths=None,
        mag_list=None,
        ages=None,
        limits=None,
        savedir='./',
        filename=None,
        show=True,
        title='',
):
    '''Plot AB magnitude spectra for several ages, with filter markers.

    Each entry of ``mag_list`` is drawn against ``wavelengths`` and labelled
    with the matching entry of ``ages`` (in Gyr).  A labelled vertical line
    marks each filter centre listed below.  The x-axis is logarithmic.

    Parameters
    ----------
    wavelengths : sequence
        Shared x-values (axis labelled in Angstrom).
    mag_list : sequence of sequences
        Magnitude spectra.  Required: its length sizes the colour cycle.
    ages : sequence
        One age per spectrum, used in the legend.
    limits : sequence of 4 floats, optional
        ``(xmin, xmax, ymin, ymax)`` axis limits.
    savedir, filename : str
        When ``filename`` is given the figure is saved to
        ``savedir + filename`` at 600 dpi.
    show : bool
        Call ``fig.show()`` at the end.
    title : str
        Axes title.
    '''
    # Import external modules (the unused ``pyfits`` import was removed: it is
    # a deprecated package whose absence made this function fail to import).
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    colormap = plt.cm.gist_ncar
    color_cycle = [colormap(i) for i in np.linspace(0, 0.9, len(mag_list))]
    fontScale = 12
    params = {
        'axes.labelsize': fontScale,
        'axes.titlesize': fontScale,
        # 'font.size' replaces the 'text.fontsize' alias, which current
        # Matplotlib rejects as an unknown rcParam.
        'font.size': fontScale,
        'legend.fontsize': fontScale * 3 / 4,
        'xtick.labelsize': fontScale,
        'ytick.labelsize': fontScale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (6, 6),
        # 'axes.prop_cycle' replaces the removed 'axes.color_cycle'.
        'axes.prop_cycle': matplotlib.cycler(color=color_cycle),
    }
    plt.rcParams.update(params)

    # Create figure; 'all' is the valid spelling of the label mode.
    fig = plt.figure()
    grid = ImageGrid(fig, (1, 1, 1),
                     nrows_ncols=(1, 1),
                     ngrids=1,
                     direction='row',
                     axes_pad=1,
                     aspect=False,
                     share_all=True,
                     label_mode='all')

    ax = grid[0]
    for i, mags in enumerate(mag_list):
        ax.plot(wavelengths, mags, label='Age = %s Gyr' % ages[i])

    # Vertical line and text label at each filter's central wavelength.
    filters = ['U', 'B', 'V', 'R', 'I', 'J', 'H', 'K', 'NUV', 'FUV']
    filter_centers = [
        3630,
        4450,
        5510,
        6580,
        8060,
        12200,
        16300,
        21900,
        2274,
        1542,
    ]
    for name, center in zip(filters, filter_centers):
        ax.axvline(x=center, ymin=0, ymax=1e10, color='k')
        ax.annotate(name,
                    xy=(center, -15),
                    textcoords='offset points',
                    xytext=(2, 3))

    if limits is not None:
        ax.set_xlim(limits[0], limits[1])
        ax.set_ylim(limits[2], limits[3])

    # Adjust aesthetics
    ax.set_xscale('log')
    ax.set_xlabel(r'$\lambda$ ($\AA$)', )
    ax.set_ylabel(r'$M_{AB}\ d\lambda$ (mag / $\AA$)')
    ax.legend(loc='upper right')
    ax.set_title(title)

    if filename is not None:
        plt.savefig(savedir + filename, bbox_inches='tight', dpi=600)
    if show:
        fig.show()
col = 0 row = 0 while row < data.shape[1]: print('yielding (%d, %d)' % (col, row)) yield (data[col, row]) if col == 0: col = 1 else: col = 0 row += 1 # <codecell> samp_im = data_gen() fig = plt.figure(figsize=(10, 10)) grid = ImageGrid( fig, 111, # similar to subplot(111) nrows_ncols=(5, 2), axes_pad=0.1, ) for ax, im in tqdm(zip(grid, samp_im), total=10): ax.imshow(im) fig.suptitle('Sampled images and their nearest neighbor') # plt.show() plt.savefig('nn_fig.png')
urcrnrlat=80, resolution='l', projection='cyl') #Have consistent longitude definition lonlon[lonlon < 0] = lonlon[lonlon < 0] + 360. x, y = m(lonlon[:, :], latlat[:, :]) fig = plt.figure(figsize=(10, 10)) grid = ImageGrid( fig, 111, # as in plt.subplot(111) nrows_ncols=(2, 2), axes_pad=0.15, share_all=True, cbar_location="bottom", cbar_mode="single", cbar_size="3%", cbar_pad=0.15, label_mode="L", ) seasons = ['DJF', 'MAM', 'JJA', 'SON'] nseas = len(seasons) lonticks = np.arange(7) * 30. latticks = np.arange(5) * 30. - 60 i = 0 for ax in grid:
def _power_spectrum_2d(image):
    '''Return the centred 2-D power spectrum ``|fftshift(fft2(image))|**2``.

    ``fftshift`` moves the low spatial frequencies to the middle of the array.
    '''
    import numpy as np
    from scipy import fftpack

    return np.abs(fftpack.fftshift(fftpack.fft2(image)))**2


def plot_power_spectrum(image,
                        title=None,
                        filename_prefix=None,
                        filename_suffix='.png',
                        show=False,
                        savedir='./'):
    '''Plot the power spectrum of *image* derived from its 2-D Fourier transform.

    Two figures are produced:

    * a 1-D azimuthally averaged spectrum, normalised to its maximum and
      overlaid on the spectrum of a Gaussian white-noise image of the same
      shape (saved as ``<prefix>_1D<suffix>``);
    * the 2-D power spectrum on a log colour scale (saved as
      ``<prefix>_2D<suffix>``).

    NaNs in *image* are replaced by 1e10 in a private copy before the FFT; the
    caller's array is not modified.

    Parameters
    ----------
    image : 2-D array
    title : str, optional
        Figure suptitle for both plots.
    filename_prefix : str, optional
        When given, both figures are saved under ``savedir``.
    filename_suffix : str
        File extension appended to the saved names.
    show : bool
        Call ``plt.show()`` after each figure.
    savedir : str
        Directory prefix for saved files.
    '''
    # import external modules
    import numpy as np
    from agpy import azimuthalAverage as radial_average
    from scipy import fftpack
    import matplotlib.pyplot as plt
    import matplotlib
    from mpl_toolkits.axes_grid1 import ImageGrid
    from matplotlib import cm

    # Work on a copy: the NaN replacement below used to overwrite the
    # caller's array in place.
    image = np.array(image, copy=True)
    image[np.isnan(image)] = 1e10

    # Determine power spectrum
    # -------------------------------------------------------------------------
    psd2D = _power_spectrum_2d(image)
    power_spectrum = radial_average(psd2D, interpnan=True)

    # Write frequency in arcmin (factor 5.0 per the axis label — TODO confirm).
    # NOTE(review): fftfreq returns negative values for the second half of its
    # output, while the radial profile bins are non-negative radii; confirm
    # this frequency axis is the intended one.
    freq = fftpack.fftfreq(len(power_spectrum))
    freq *= 5.0

    # Reference: power spectrum of Gaussian white noise with the same shape.
    noise_image = np.random.normal(scale=0.1, size=image.shape)
    power_spectrum_noise = radial_average(_power_spectrum_2d(noise_image),
                                          interpnan=True)

    # Plot power spectrum 1D
    # -------------------------------------------------------------------------
    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    font_scale = 12
    params = {
        'axes.labelsize': font_scale,
        'axes.titlesize': font_scale,
        # 'font.size' replaces the 'text.fontsize' alias, which current
        # Matplotlib rejects as an unknown rcParam.
        'font.size': font_scale,
        'legend.fontsize': font_scale * 3 / 4.0,
        'xtick.labelsize': font_scale,
        'ytick.labelsize': font_scale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (5, 5),
    }
    plt.rcParams.update(params)

    # Create figure instance
    fig = plt.figure()
    nrows = 1
    ncols = 1
    ngrids = 1
    imagegrid = ImageGrid(
        fig,
        (1, 1, 1),
        nrows_ncols=(nrows, ncols),
        ngrids=ngrids,
        axes_pad=0.25,
        aspect=False,
        label_mode='L',
        share_all=True,
        cbar_pad=0.1,
        cbar_size=0.2,
    )

    ax = imagegrid[0]
    ax.plot(freq,
            power_spectrum / np.nanmax(power_spectrum),
            color='k',
            linestyle='-',
            linewidth=1.5,
            drawstyle='steps-mid',
            label='Data Residuals')
    ax.plot(freq,
            power_spectrum_noise / np.nanmax(power_spectrum_noise),
            color='r',
            linestyle='-',
            linewidth=0.4,
            drawstyle='steps-mid',
            label='White Noise Residuals')
    ax.legend(loc='best')
    ax.set_yscale('log')
    ax.set_xlabel('Spatial Frequency [1/arcmin]')
    ax.set_ylabel('Normalized Power Spectrum')
    ax.set_xlim(0, 0.4)

    if title is not None:
        fig.suptitle(title, fontsize=font_scale)
    if filename_prefix is not None:
        plt.savefig(savedir + filename_prefix + '_1D' + filename_suffix,
                    bbox_inches='tight')
    if show:
        plt.show()

    # Plot power spectrum image
    # -------------------------
    # Create figure instance
    fig = plt.figure()
    nrows = 1
    ncols = 1
    ngrids = 1
    imagegrid = ImageGrid(
        fig,
        (1, 1, 1),
        nrows_ncols=(nrows, ncols),
        ngrids=ngrids,
        axes_pad=0.25,
        aspect=False,
        label_mode='L',
        share_all=True,
        cbar_pad=0.1,
        cbar_size=0.2,
    )

    ax = imagegrid[0]
    # imshow's extent is (left, right, bottom, top): x spans the columns
    # (shape[1]) and y spans the rows (shape[0]).  The original had them
    # swapped, which mislabels non-square images.
    extent = [
        -image.shape[1] / 2.0, +image.shape[1] / 2.0,
        -image.shape[0] / 2.0, +image.shape[0] / 2.0
    ]
    ax.imshow(psd2D,
              origin='lower',
              cmap=cm.gist_heat,
              norm=matplotlib.colors.LogNorm(),
              extent=extent)
    ax.set_xlabel('Spatial Frequency in Right Ascension')
    ax.set_ylabel('Spatial Frequency in Declination')
    ax.set_xlim(-20, 20)
    ax.set_ylim(-20, 20)

    if title is not None:
        fig.suptitle(title, fontsize=font_scale)
    if filename_prefix is not None:
        plt.savefig(savedir + filename_prefix + '_2D' + filename_suffix,
                    bbox_inches='tight')
    if show:
        plt.show()
# Mask with a disc R = R* disc((retina_shape[0],retina_shape[0]), (retina_shape[0]//2,retina_shape[0]//2), retina_shape[0]//2) # Take half-retina R = R[:,retina_shape[1]:] # Project to colliculus SC = R[P[...,0], P[...,1]] fig = plt.figure(figsize=(10,15), facecolor='w') ###################### ax1, ax2 = ImageGrid(fig, 211, nrows_ncols=(1,2), axes_pad=0.5) polar_frame(ax1, legend=True) polar_imshow(ax1, R, vmin=0, vmax=5) logpolar_frame(ax2, legend=True) logpolar_imshow(ax2, SC, vmin=0, vmax=5) ax1.text(1.1, 1.1, u"a", ha="left", va="bottom", fontsize=20, fontweight='bold') #ax1.text(0., -1.28, u"0° 90°", # ha="left", va="bottom", fontsize=10) ################################ ax1, ax2 = ImageGrid(fig, 212, nrows_ncols=(1,2), axes_pad=0.5) polar_frame(ax1, legend=True,reduced=True) ''' zax = zoomed_inset_axes(ax1, 6, loc=1) polar_frame(zax, zoom=True)
def plotImageGrid(images, nrows_ncols=None, extent=None, clim=None,
                  interpolation='none', cmap='gray', imScale=2., cbar=True,
                  titles=None, titlecol=('r', 'y')):
    """Show several images in an ImageGrid that shares a single colour bar.

    Each entry of *images* may be a plain 2-D array, or an object exposing
    ``computeImage``, ``getImage``, ``getMaskedImage`` and/or ``getArray``;
    these accessors are applied in that order to reach the pixel array.
    Every image is passed through ``zscale_image`` before display.

    Args:
        images: sequence of images (see above).
        nrows_ncols: grid shape; defaults to a near-square layout with
            floor(sqrt(n)) rows.  Each dimension is clamped to at least 1.
        extent: ``(x0, x1, y0, y1)`` used to crop each array and as the
            ``imshow`` extent.  If None and the image has a bbox, the bbox is
            used for that image only.
        clim: colour limits; when ``cbar`` is set the data are also clipped
            to them.
        interpolation, cmap: forwarded to ``imshow``.
        imScale: figure inches per grid cell.
        cbar: draw the colour bar.
        titles: optional per-image titles, same length as ``images``.
        titlecol: ``(text colour, outline colour)`` for the titles.

    Returns:
        The ``ImageGrid`` instance.

    Side effect: switches the global Matplotlib style to 'ggplot'.
    """
    import matplotlib.pyplot as plt
    import matplotlib
    matplotlib.style.use('ggplot')
    from mpl_toolkits.axes_grid1 import ImageGrid

    def add_inner_title(ax, title, loc, size=None, **kwargs):
        """Draw *title* inside *ax* as outlined anchored text and return it."""
        from matplotlib.offsetbox import AnchoredText
        from matplotlib.patheffects import withStroke
        if size is None:
            size = dict(size=plt.rcParams['legend.fontsize'], color=titlecol[0])
        at = AnchoredText(title, loc=loc, prop=size,
                          pad=0., borderpad=0.5,
                          frameon=False, **kwargs)
        ax.add_artist(at)
        at.txt._text.set_path_effects(
            [withStroke(foreground=titlecol[1], linewidth=3)])
        return at

    if nrows_ncols is None:
        # Near-square layout.  Builtin int/float replace np.int/np.float,
        # which were removed in NumPy 1.24; max(1, ...) avoids dividing by
        # zero for an empty image list.
        nrows = max(1, int(np.floor(np.sqrt(len(images)))))
        nrows_ncols = (nrows, int(np.ceil(len(images) / float(nrows))))
    # Clamp each dimension to at least 1.  A new tuple is built because item
    # assignment raised TypeError whenever nrows_ncols was a tuple.
    nrows_ncols = (max(1, nrows_ncols[0]), max(1, nrows_ncols[1]))

    size = (nrows_ncols[1] * imScale, nrows_ncols[0] * imScale)
    fig = plt.figure(1, size)
    igrid = ImageGrid(
        fig,
        111,  # similar to subplot(111)
        nrows_ncols=nrows_ncols,
        direction='row',
        axes_pad=0.1,  # pad between axes in inch.
        label_mode="L",
        cbar_location="right",
        cbar_mode="single",
        cbar_size='7%')

    extentWasNone = False
    for i in range(len(images)):
        ii = images[i]
        # Unwrap image-like objects down to a plain array.
        if hasattr(ii, 'computeImage'):
            ii = ii.computeImage()
        if hasattr(ii, 'getImage'):
            ii = ii.getImage()
        if hasattr(ii, 'getMaskedImage'):
            ii = ii.getMaskedImage().getImage()
        if hasattr(ii, 'getArray'):
            bbox = ii.getBBox()
            if extent is None:
                # Use this image's bbox as its extent; reset after drawing.
                extentWasNone = True
                extent = (bbox.getBeginX(), bbox.getEndX(),
                          bbox.getBeginY(), bbox.getEndY())
            ii = ii.getArray()
        if cbar and clim is not None:
            ii = np.clip(ii, clim[0], clim[1])
        if extent is not None:
            # NOTE(review): extent is (x0, x1, y0, y1) but this slices rows
            # (y) with x bounds and uses absolute coordinates even for a bbox
            # that does not start at 0 — confirm this crop is intended.
            ii = ii[extent[0]:extent[1], extent[2]:extent[3]]
        ii = zscale_image(ii)
        im = igrid[i].imshow(ii, origin='lower', interpolation=interpolation,
                             cmap=cmap, extent=extent, clim=clim)
        if cbar:
            igrid[i].cax.colorbar(im)
        if titles is not None:
            # titles is assumed to be indexable with the same length as images.
            t = add_inner_title(igrid[i], titles[i], loc=2)
            t.patch.set_ec("none")
            t.patch.set_alpha(0.5)
        if extentWasNone:
            extent = None
        extentWasNone = False

    return igrid
y_pred = cod2.argmax(1) #y_pred = kmeans.predict(cod2) print(np.unique(y_pred)) cat1 = 4 ind, = np.where(y_pred == cat1) np.random.shuffle(ind) ims = 100 * [None] for j in range(100): ims[j] = (images[ind[j]]) plt.ion() fig = plt.figure(figsize=(10, 10)) grid = ImageGrid(fig, 111, nrows_ncols=(8, 8), axes_pad=0.0, label_mode=None) i = 0 for ax, im in zip(grid, ims): #ax.tick_params(labelbottom=False,labelleft=False) ax.imshow(im[:, :]) ax.set_xticks([-1]) ax.set_yticks([-1]) rounded = [round(num1, 2) for num1 in cod2[ind[i]]] ax.text(0.05, 0.9, str(rounded[0:3]), transform=ax.transAxes, fontsize=5, color=[1, 1, 1]) ax.text(0.1, 0.8,
def _event_display_categorical_cmap(cmaplist, n_colors, n_categories):
    """Return ``(cmap, norm)`` giving one colour band per integer 0..n_categories-1.

    *cmaplist* and *n_colors* are forwarded to ``lsc.from_list``; the norm is
    a ``BoundaryNorm`` with unit-wide bins starting at 0.
    """
    cmap = lsc.from_list('Custom cmap', cmaplist, n_colors)
    bounds = np.linspace(0, n_categories, n_categories + 1)
    return cmap, matplotlib.colors.BoundaryNorm(bounds, cmap.N)


def _event_display_scatter(num, x, y, c, xlabel, ylabel, cbar_label, path,
                           cmap, norm=None, ticks=None, s=5):
    """Draw a 10x8 inch scatter of (x, y) coloured by *c* into figure *num*,
    add a colourbar labelled *cbar_label* and save the figure to *path*."""
    fig = plt.figure(num=num, clear=True)
    fig.set_size_inches(10, 8)
    ax = fig.add_subplot(111)
    points = ax.scatter(x, y, c=c, s=s, cmap=cmap, norm=norm, marker='.')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    cbar = fig.colorbar(points, ticks=ticks, pad=0.1)
    cbar.set_label(cbar_label)
    fig.savefig(path)


def _event_display_module_image(num, data, cmap, cbar_label, path):
    """Show a (row, column) wall-module map flipped vertically and save it."""
    fig = plt.figure(num=num, clear=True)
    fig.set_size_inches(10, 4)
    ax = fig.add_subplot(111)
    img = ax.imshow(np.flip(data, axis=0), cmap=cmap)
    ax.set_xlabel('arc index')
    ax.set_ylabel('z index')
    cbar = fig.colorbar(img, pad=0.1)
    cbar.set_label(cbar_label)
    fig.savefig(path)


def _event_display_histogram(num, values, xlabel, path):
    """Save a 50-bin normalised histogram of *values*."""
    fig = plt.figure(num=num, clear=True)
    fig.set_size_inches(10, 8)
    ax = fig.add_subplot(111)
    ax.hist(values, 50, density=True, facecolor='blue', alpha=0.75)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("PMT's above threshold")
    fig.savefig(path)


def _event_display_wall_grid(num, planes, cmap, path, summary=None):
    """Show the 19 planes ``planes[:, :, 0..18]`` on a 4x5 ImageGrid and save it.

    When *summary* is given it is drawn in the 20th cell and used for the
    shared colourbar.
    """
    fig = plt.figure(num=num, clear=True)
    fig.set_size_inches(15, 5)
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(4, 5),
        axes_pad=0.0,
        share_all=True,
        label_mode="L",
        cbar_location="top",
        cbar_mode="single",
    )
    for i in range(19):
        grid[i].imshow(np.flip(planes[:, :, i], axis=0), cmap=cmap)
    if summary is not None:
        img = grid[19].imshow(np.flip(summary, axis=0), cmap=cmap)
        grid.cbar_axes[0].colorbar(img)
    fig.savefig(path)


def _event_display_geometry(geo, write_dir, pmt_style, row_style, col_style):
    """Plot the PMT layout of *geo* (figures 101-107) and return ``r_max``.

    Each ``*_style`` argument is a ``(cmap, norm)`` pair.  ``r_max`` is the
    largest PMT distance from the z axis; it converts azimuth into arc
    length along the wall.
    """
    num_pmts = geo.GetWCNumPMT()
    # Tubes are visited in random order, presumably so that no module
    # position is systematically drawn on top of the others.
    tube_ids = np.arange(num_pmts)
    np.random.shuffle(tube_ids)
    modules = module_index(tube_ids)
    # NOTE(review): uses the 0-based id % 19, while the hit plots use
    # pmt_in_module_id(); confirm the two agree.
    in_module = (tube_ids % 19).astype(float)
    pos_x = np.zeros(num_pmts)
    pos_y = np.zeros(num_pmts)
    pos_z = np.zeros(num_pmts)
    for i, tube in enumerate(tube_ids):
        pmt = geo.GetPMT(tube)
        # Geometry coordinates (2, 0, 1) are used as (x, y, z).
        pos_x[i] = pmt.GetPosition(2)
        pos_y[i] = pmt.GetPosition(0)
        pos_z[i] = pmt.GetPosition(1)

    r_max = np.amax(np.hypot(pos_x, pos_y))
    arc = r_max * np.arctan2(pos_y, pos_x)
    wall = np.where(is_barrel(modules))
    top = np.where(is_top(modules))
    bottom = np.where(is_bottom(modules))
    wall_row, wall_col = row_col(modules[wall])

    pmt_cmap, pmt_norm = pmt_style
    row_cmap, row_norm = row_style
    col_cmap, col_norm = col_style
    arc_label = 'arc along the wall'

    _event_display_scatter(101, arc, pos_z, in_module, arc_label, 'z',
                           "pmt in module",
                           write_dir + "pos_arc_z_disp_all_tubes.pdf",
                           pmt_cmap, pmt_norm, range(20))
    _event_display_scatter(102, pos_x, pos_y, in_module, 'x', 'y',
                           "pmt in module",
                           write_dir + "pos_x_y_disp_all_tubes.pdf",
                           pmt_cmap, pmt_norm, range(20))
    _event_display_scatter(103, arc[wall], pos_z[wall], in_module[wall],
                           arc_label, 'z', "pmt in module",
                           write_dir + "pos_arc_z_disp_wall_tubes.pdf",
                           pmt_cmap, pmt_norm, range(20))
    _event_display_scatter(104, arc[wall], pos_z[wall], wall_row,
                           arc_label, 'z', "wall module row",
                           write_dir + "pos_arc_z_disp_wall_tubes_color_row.pdf",
                           row_cmap, row_norm, range(16))
    _event_display_scatter(105, arc[wall], pos_z[wall], wall_col,
                           arc_label, 'z', "wall module column",
                           write_dir + "pos_arc_z_disp_wall_tubes_color_col.pdf",
                           col_cmap, col_norm, range(40))
    _event_display_scatter(106, pos_x[top], pos_y[top], in_module[top],
                           'x', 'y', "pmt in module",
                           write_dir + "pos_x_y_disp_top_tubes.pdf",
                           pmt_cmap, pmt_norm, range(20))
    _event_display_scatter(107, pos_x[bottom], pos_y[bottom],
                           in_module[bottom], 'x', 'y', "pmt in module",
                           write_dir + "pos_x_y_disp_bottom_tubes.pdf",
                           pmt_cmap, pmt_norm, range(20))
    return r_max


def _event_display_event(tree, geo, ev, r_max, write_dir, cmap):
    """Draw the hit displays (figures 1-11) for entry *ev* of *tree*.

    The trigger with the most digitized Cherenkov hits is displayed.  Output
    file names carry the event number and that trigger's index.  Returns
    without plotting when the selected trigger has no digitized hits.
    """
    tree.GetEvent(ev)
    super_event = tree.wcsimrootevent
    n_triggers = super_event.GetNumberOfEvents()
    print("number of sub events: " + str(n_triggers))

    # ``index`` is both the trigger displayed and the one named in outputs.
    index = 0
    most_hits = 0
    for trigger in range(n_triggers):
        n_hits = super_event.GetTrigger(trigger).GetNcherenkovdigihits()
        if n_hits > most_hits:
            most_hits = n_hits
            index = trigger
    event = super_event.GetTrigger(index)

    header = event.GetHeader()
    print("event date and number: " + str(header.GetDate()) + " " +
          str(header.GetEvtNum()))
    n_digi = event.GetNcherenkovdigihits()
    print("Ncherenkovdigihits " + str(n_digi))
    if n_digi == 0:
        print("event, trigger has no hits " + str(ev) + " " + str(index))
        return

    # PMT numbering notes (tube ids start at 1 and are shifted to 0-based
    # below):
    #   * ids are contiguous; every 19 consecutive PMTs form one mPMT
    #     module, so (id - 1) / 19 is the module number;
    #   * id % 19 is the position inside the module: 1-12 outer ring,
    #     13-18 inner ring, 0 the centre PMT;
    #   * modules run over the barrel rings from the second-highest ring
    #     down to the lowest, then the bottom end-cap row by row, then the
    #     skipped highest barrel ring, then the top end-cap row by row;
    #   * end-cap rows were first noted as 6, 8, 10, 10, 10, 10, 10, 10, 8,
    #     6 modules, but a later note says they are 2, 6, 8, 10, 10, 12 and
    #     back down.
    digihits = event.GetCherenkovDigiHits()
    pos_x = np.zeros(n_digi)
    pos_y = np.zeros(n_digi)
    pos_z = np.zeros(n_digi)
    dir_u = np.zeros(n_digi)
    dir_v = np.zeros(n_digi)
    dir_w = np.zeros(n_digi)
    q = np.zeros(n_digi)
    t = np.zeros(n_digi)
    pmt_index = np.zeros(n_digi, dtype=np.int32)
    for i in range(n_digi):
        hit = digihits.At(i)
        tube = hit.GetTubeId() - 1
        pmt = geo.GetPMT(tube)
        pmt_index[i] = tube
        # Geometry coordinates (2, 0, 1) are used as (x, y, z).
        pos_x[i] = pmt.GetPosition(2)
        pos_y[i] = pmt.GetPosition(0)
        pos_z[i] = pmt.GetPosition(1)
        dir_u[i] = pmt.GetOrientation(2)
        dir_v[i] = pmt.GetOrientation(0)
        dir_w[i] = pmt.GetOrientation(1)
        q[i] = hit.GetQ()
        t[i] = hit.GetT()

    modules = module_index(pmt_index)
    in_module = pmt_in_module_id(pmt_index)
    wall = np.where(is_barrel(modules))
    top = np.where(is_top(modules))
    bottom = np.where(is_bottom(modules))
    arc = r_max * np.arctan2(pos_y, pos_x)

    # Wall hits as a (16, 40, 38) array indexed by module row, module column
    # and channel: channels 0-18 hold the charge per in-module position,
    # channels 19-37 the time.
    wall_row, wall_col = row_col(modules[wall])
    wall_rect = np.zeros((16, 40, 38))
    wall_rect[wall_row, wall_col, in_module[wall]] = q[wall]
    wall_rect[wall_row, wall_col, in_module[wall] + 19] = t[wall]
    q_max_module = np.amax(wall_rect[:, :, 0:19], axis=-1)
    q_sum_module = np.sum(wall_rect[:, :, 0:19], axis=-1)

    tag = "_ev_{}_trig_{}.pdf".format(ev, index)

    fig1 = plt.figure(num=1, clear=True)
    fig1.set_size_inches(10, 8)
    ax1 = fig1.add_subplot(111, projection='3d', azim=35, elev=20)
    hits = ax1.scatter(pos_x, pos_y, pos_z, c=q, s=2, alpha=0.4, cmap=cmap,
                       marker='.')
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    ax1.set_zlabel('z')
    cbar1 = fig1.colorbar(hits, pad=0.03)
    cbar1.set_label("charge")
    fig1.savefig(write_dir + "ev_disp" + tag)

    # Arrows along each hit PMT's orientation, length proportional to charge
    # (largest = 500) and coloured by hit time.  The colourbar uses the same
    # colormap as the arrows.
    scaled_q = 500 * q / np.amax(q)
    time_norm = plt.Normalize()
    time_cmap = plt.cm.spring
    colors = time_cmap(time_norm(t))  # first call autoscales time_norm to t
    fig2 = plt.figure(num=2, clear=True)
    fig2.set_size_inches(10, 8)
    ax2 = fig2.add_subplot(111, projection='3d', azim=35, elev=20)
    ax2.quiver(pos_x, pos_y, pos_z, dir_u * scaled_q, dir_v * scaled_q,
               dir_w * scaled_q, colors=colors, alpha=0.4, cmap=time_cmap)
    ax2.set_xlabel('x')
    ax2.set_ylabel('y')
    ax2.set_zlabel('z')
    mappable = matplotlib.cm.ScalarMappable(cmap=time_cmap, norm=time_norm)
    mappable.set_array([])
    cbar2 = fig2.colorbar(mappable, ax=ax2, pad=0.03)
    cbar2.set_label("time")
    fig2.savefig(write_dir + "ev_disp_quiver" + tag)

    _event_display_scatter(3, arc[wall], pos_z[wall], q[wall],
                           'arc along the wall', 'z', "charge",
                           write_dir + "ev_disp_wall" + tag, cmap, s=2)
    _event_display_scatter(4, pos_x[top], pos_y[top], q[top], 'x', 'y',
                           "charge", write_dir + "ev_disp_top" + tag, cmap,
                           s=2)
    _event_display_scatter(5, pos_x[bottom], pos_y[bottom], q[bottom],
                           'x', 'y', "charge",
                           write_dir + "ev_disp_bottom" + tag, cmap, s=2)
    _event_display_module_image(6, q_sum_module, cmap,
                                "total charge in module",
                                write_dir + "q_sum_disp" + tag)
    _event_display_module_image(7, q_max_module, cmap,
                                "maximum charge in module",
                                write_dir + "q_max_disp" + tag)
    _event_display_histogram(8, q, 'charge', write_dir + "q_pmt_disp" + tag)
    _event_display_histogram(9, t, 'time', write_dir + "t_pmt_disp" + tag)
    _event_display_wall_grid(10, wall_rect[:, :, 0:19], cmap,
                             write_dir + "q_disp_grid" + tag,
                             summary=q_max_module)
    _event_display_wall_grid(11, wall_rect[:, :, 19:38], cmap,
                             write_dir + "t_disp_grid" + tag)


def event_display(config):
    """Produce event-display plots for every request in ``config.input_file``.

    ``config.input_file`` is a sequence whose first element is the request
    file.  Each non-blank line of that file is
    ``<softmax> <root file> <entry>``.  For each request, the PMT layout and
    the hits of the requested entry are plotted into
    ``<output_dir><get_class(root file)>_softmax<int part>_<frac part>/``.
    Requests whose selected trigger has no digitized hits are skipped.
    Calls ``sys.exit()`` when a ROOT file name contains none of
    ``_gamma``/``_e``/``_mu``/``_pi0``.

    ``config.input_file`` and ``config.output_dir`` are modified in place.
    """
    config.input_file = config.input_file[0]
    if not config.output_dir.endswith('/'):
        config.output_dir += '/'
    if not os.path.isdir(config.output_dir):
        os.mkdir(config.output_dir)
    print("Reading request from: " + str(config.input_file))
    print("output directory: " + str(config.output_dir))

    with open(config.input_file, 'r') as request_file:
        requests = request_file.readlines()

    # Colour maps do not depend on the request, so build them once.
    charge_cmap = matplotlib.cm.plasma
    cmaplist = [charge_cmap(i) for i in range(charge_cmap.N)]
    pmt_style = _event_display_categorical_cmap(cmaplist, charge_cmap.N, 19)
    row_style = _event_display_categorical_cmap(cmaplist, charge_cmap.N, 16)
    col_style = _event_display_categorical_cmap(cmaplist, charge_cmap.N, 40)

    for line in requests:
        fields = line.split()
        if not fields:
            continue
        softmax = fields[0].strip()
        input_file = fields[1].strip()
        ev = int(fields[2].strip())
        print("now processing " + input_file + " at index " + str(ev))

        event_class = get_class(input_file)
        softmax_parts = str(softmax).split('.')
        write_dir = (config.output_dir + event_class + "_softmax" +
                     softmax_parts[0] + '_' + softmax_parts[1] + "/")
        if not os.path.isdir(write_dir):
            os.mkdir(write_dir)

        if not any(tag in input_file
                   for tag in ("_gamma", "_e", "_mu", "_pi0")):
            print("Unknown input file particle type")
            sys.exit()

        root_file = ROOT.TFile(input_file, "read")
        try:
            tree = root_file.Get("wcsimT")
            print("number of entries in the tree: " + str(tree.GetEntries()))
            geotree = root_file.Get("wcsimGeoT")
            print("number of entries in the geometry tree: " +
                  str(geotree.GetEntries()))
            geotree.GetEntry(0)
            geo = geotree.wcsimrootgeom
            r_max = _event_display_geometry(geo, write_dir, pmt_style,
                                            row_style, col_style)
            _event_display_event(tree, geo, ev, r_max, write_dir, charge_cmap)
        finally:
            root_file.Close()
def plot_av_vs_nhi_grid(nhi_images, av_images, nhi_error_images=None,
                        av_error_images=None, limits=None, savedir='./',
                        filename=None, show=False, scale=('linear', 'linear'),
                        returnimage=False, hess_binsize=None, title='',
                        plot_type='hexbin', color_scale='linear'):
    """Plot A_V against N(HI) with error bars, one panel per image pair.

    Parameters
    ----------
    nhi_images, av_images : sequence of array_like
        Paired N(HI) (10^20 cm^-2) and A_V (mag) images.
    nhi_error_images, av_error_images : sequence, optional
        Per-panel uncertainties.  Each entry may be an array matching the
        image or a scalar applied to every pixel.  ``None`` (for the whole
        sequence or an entry) draws no error bars on that axis.
    limits : (xmin, xmax, ymin, ymax), optional
        Axis limits applied to every panel.
    savedir, filename : str
        When *filename* is given, the figure is saved to
        ``savedir + filename``.
    show : bool
        Call ``fig.show()`` at the end.
    scale : (str, str)
        x and y axis scales, e.g. ``'linear'`` or ``'log'``.  Log axes clip
        non-positive values.
    title : str
        Used as every panel's title and as the figure suptitle.
    returnimage, hess_binsize, plot_type, color_scale
        Accepted for call compatibility; currently unused.

    Only pixels where both images (and any supplied errors) are not NaN and
    both images are positive are plotted.

    Raises
    ------
    ValueError
        If *av_images* is empty.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import ImageGrid

    n_panels = len(av_images)
    if n_panels == 0:
        raise ValueError('av_images is empty; nothing to plot')

    # Smallest near-square grid: drop the last row when n * (n - 1) cells
    # already hold every panel.
    n = int(np.ceil(n_panels ** 0.5))
    if n * (n - 1) >= n_panels:
        nrows, ncols = n - 1, n
        y_scaling = 1.0 - 1.0 / n
    else:
        nrows, ncols = n, n
        y_scaling = 1.0

    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    font_scale = 12
    plt.rcParams.update({
        'axes.labelsize': font_scale,
        'axes.titlesize': font_scale,
        # 'text.fontsize' is no longer a valid rcParam; 'font.size' replaces it.
        'font.size': font_scale,
        'legend.fontsize': font_scale * 3 / 4.0,
        'xtick.labelsize': font_scale,
        'ytick.labelsize': font_scale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': True,
        'figure.figsize': (8, 8 * y_scaling),
    })

    def _entry(seq, i):
        """Return seq[i], or None when the whole sequence is None."""
        return None if seq is None else seq[i]

    def _selected_errors(error, image, indices):
        """Flattened errors of the selected pixels; scalars are broadcast."""
        if error is None:
            return None
        if np.ndim(error) == 0:
            return (error * np.ones(image[indices].shape)).ravel()
        return np.asarray(error[indices]).ravel()

    def _apply_scale(set_scale, name, legacy_kwarg):
        """Set an axis scale, clipping non-positive values on log axes.

        The clipping keyword is only meaningful for log scales (recent
        matplotlib rejects it for linear axes).  It is spelled
        'nonpositive' in matplotlib >= 3.3 and 'nonposx'/'nonposy' before.
        """
        if name != 'log':
            set_scale(name)
            return
        try:
            set_scale(name, nonpositive='clip')
        except (TypeError, ValueError):
            set_scale(name, **{legacy_kwarg: 'clip'})

    # Create figure instance
    fig = plt.figure()
    imagegrid = ImageGrid(fig, (1, 1, 1),
                          nrows_ncols=(nrows, ncols),
                          ngrids=n_panels,
                          axes_pad=0.25,
                          aspect=False,
                          label_mode='L',
                          share_all=True)

    for i in range(n_panels):
        av = av_images[i]
        nhi = nhi_images[i]
        av_error = _entry(av_error_images, i)
        nhi_error = _entry(nhi_error_images, i)

        # x == x is False only for NaN.
        mask = (av == av) & (nhi == nhi) & (nhi > 0) & (av > 0)
        for error in (av_error, nhi_error):
            if error is not None:
                mask = mask & (error == error)
        indices = np.where(mask)

        ax = imagegrid[i]
        ax.errorbar(nhi[indices].ravel(),
                    av[indices].ravel(),
                    xerr=_selected_errors(nhi_error, nhi, indices),
                    yerr=_selected_errors(av_error, av, indices),
                    alpha=0.3,
                    color='k',
                    marker='^',
                    ecolor='k',
                    linestyle='none',
                    markersize=4)

        _apply_scale(ax.set_xscale, scale[0], 'nonposx')
        _apply_scale(ax.set_yscale, scale[1], 'nonposy')

        if limits is not None:
            ax.set_xlim(limits[0], limits[1])
            ax.set_ylim(limits[2], limits[3])

        # Adjust aesthetics
        ax.set_xlabel(r'$N(HI)$ (10$^{20}$ cm$^{-2}$)')
        ax.set_ylabel(r'A$_{\rm V}$ (mag)')
        ax.set_title(title)
        ax.grid(True)

    if title is not None:
        fig.suptitle(title, fontsize=font_scale * 1.5)
    if filename is not None:
        plt.savefig(savedir + filename)
    if show:
        fig.show()
def attention_epoch_plot(net,
                         folder_name,
                         source_images,
                         logged=False,
                         width=5,
                         device=torch.device('cpu'),
                         layer_name_base='attention',
                         layer_no=2,
                         cmap_name='magma',
                         figsize=(100, 100)):
    """Plot source images next to their attention maps at several epochs.

    Attention maps come from ``AttentionImagesByEpoch`` (called with
    ``epoch=2000``) for every saved epoch.  *width* of those epochs, evenly
    spaced, are shown.  Each grid row holds one source image (with a contour
    overlay) followed by its attention maps in epoch order.

    Args:
        net: Model forwarded to ``AttentionImagesByEpoch``.
        folder_name: Folder forwarded to ``AttentionImagesByEpoch``.
        source_images: Batch of source images; ``shape[0]`` is the number
            of rows.
        logged: If true, source images are exponentiated before plotting and
            every panel is drawn as ``np.log`` of its data.
        width: Number of epochs to show, capped at the number available.
        device, layer_name_base, layer_no: Forwarded to
            ``AttentionImagesByEpoch``.
        cmap_name: Matplotlib colormap name.  ``'RGB'`` passes
            ``mean=False`` to ``AttentionImagesByEpoch`` (presumably keeping
            per-channel maps) and plots with ``'magma'``.  Any other name
            passes ``mean=True``.
        figsize: Figure size in inches.

    Prints the epoch labels of the displayed columns and calls ``plt.show()``.
    """
    # mean_ must be defined for every colormap, not only for 'RGB'.
    mean_ = True
    if cmap_name == 'RGB':
        mean_ = False
        cmap_name = 'magma'

    # Generate attention maps for each available epoch.
    attention_maps, _, epoch_labels = AttentionImagesByEpoch(
        source_images,
        folder_name,
        net,
        epoch=2000,
        device=device,
        layer_name_base=layer_name_base,
        layer_no=layer_no,
        mean=mean_)

    # Maps are stored epoch-major: entry sample_number * epoch + sample.
    sample_number = source_images.shape[0]
    n_epochs = np.asarray(attention_maps).shape[0] // sample_number
    if width <= n_epochs:
        epoch_idx = np.linspace(0, n_epochs - 1, num=width, dtype=np.int32)
    else:
        epoch_idx = range(n_epochs)
        width = n_epochs

    # Row-major order: source image, then its maps at the chosen epochs.
    imgs = []
    labels = []
    for j in range(sample_number):
        source = source_images[j].squeeze()
        imgs.append(np.exp(source) if logged else source)
        for i in epoch_idx:
            imgs.append(attention_maps[sample_number * i + j])
            # Epoch labels are shared by all samples; collect them once.
            if len(labels) < width:
                labels.append(epoch_labels[sample_number * i])

    fig = plt.figure(figsize=figsize)
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(sample_number, width + 1),
        axes_pad=0.02,  # pad between axes in inch.
    )

    for idx, (ax, im) in enumerate(zip(grid, imgs)):
        # Channels-first RGB -> (H, W, 3) for imshow.
        if im.shape[0] == 3:
            im = im.transpose(1, 2, 0)
        ax.imshow(np.log(im) if logged else im, cmap=cmap_name)
        # The first column of every row is the source image.
        if idx % (width + 1) == 0:
            ax.contour(im, 1, cmap='cool', alpha=0.5)
        ax.axis('off')
    print(
        f'Source images followed by their respective averaged attention maps at epochs:\n{labels}'
    )
    plt.show()
def plot_av_vs_nhi(nhi_image, av_image, limits=None, savedir='./',
                   filename=None, show=False, scale=('linear', 'linear'),
                   returnimage=False, hess_binsize=None, title='',
                   plot_type='hexbin', color_scale='linear'):
    """Plot A_V against N(HI) for one pair of images.

    Parameters
    ----------
    nhi_image, av_image : array_like
        N(HI) (10^20 cm^-2) and A_V (mag) images.  Only pixels that are not
        NaN and positive in both are used.
    limits : (xmin, xmax, ymin, ymax), optional
        Axis limits.
    savedir, filename : str
        When *filename* is given, the figure is saved to
        ``savedir + filename`` at 600 dpi.
    show : bool
        Call ``fig.show()`` at the end.
    scale : (str, str)
        x and y axis scales.
    returnimage : bool
        Return the plotted artist (hexbin collection or scatter), or None
        when nothing was drawn.
    title : str
        Axes title.
    plot_type : {'hexbin', 'scatter'}
        'hexbin' draws a binned density with a colourbar; 'scatter' draws
        individual points.
    color_scale : {'linear', 'log'}
        Colour normalisation of the hexbin counts.
    hess_binsize
        Accepted for call compatibility; unused.
    """
    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Drop NaNs (x == x is False only for NaN) and non-positive values.
    indices = np.where((nhi_image == nhi_image) &
                       (av_image == av_image) &
                       (nhi_image > 0) &
                       (av_image > 0))
    nhi_nonans = nhi_image[indices].ravel()
    av_nonans = av_image[indices].ravel()

    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    font_scale = 10
    plt.rcParams.update({
        'axes.labelsize': font_scale,
        'axes.titlesize': font_scale,
        # 'text.fontsize' is no longer a valid rcParam; 'font.size' replaces it.
        'font.size': font_scale,
        'legend.fontsize': font_scale * 3 / 4,
        'xtick.labelsize': font_scale,
        'ytick.labelsize': font_scale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (4, 4),
    })

    # ImageGrid expects None (not the string 'None') for "no colourbar".
    cbar_mode = None if plot_type == 'scatter' else 'single'

    fig = plt.figure()
    imagegrid = ImageGrid(fig, (1, 1, 1),
                          nrows_ncols=(1, 1),
                          ngrids=1,
                          axes_pad=0.25,
                          aspect=False,
                          label_mode='L',
                          share_all=True,
                          cbar_mode=cbar_mode,
                          cbar_pad=0.1,
                          cbar_size=0.2)
    ax = imagegrid[0]

    image = None
    if plot_type == 'hexbin':
        hexbin_kwargs = dict(mincnt=1, xscale=scale[0], yscale=scale[1],
                             cmap=cm.gist_stern)
        if color_scale == 'linear':
            image = ax.hexbin(nhi_nonans, av_nonans, **hexbin_kwargs)
        elif color_scale == 'log':
            image = ax.hexbin(nhi_nonans, av_nonans,
                              norm=matplotlib.colors.LogNorm(),
                              gridsize=(100, 200), **hexbin_kwargs)
        if image is not None:
            cb = ax.cax.colorbar(image)
            # Older axes_grid1 colourbars provide set_label_text; the
            # standard matplotlib Colorbar provides set_label.
            set_label = getattr(cb, 'set_label_text', None) or cb.set_label
            set_label('Bin Counts')
    elif plot_type == 'scatter':
        image = ax.scatter(nhi_nonans, av_nonans, alpha=0.3, color='k')
        ax.set_xscale(scale[0])
        ax.set_yscale(scale[1])

    if limits is not None:
        ax.set_xlim(limits[0], limits[1])
        ax.set_ylim(limits[2], limits[3])

    # Adjust aesthetics
    ax.set_xlabel(r'$N(HI)$ (10$^{20}$ cm$^{-2}$)')
    ax.set_ylabel(r'A$_{\rm V}$ (mag)')
    ax.set_title(title)
    ax.grid(True)

    if filename is not None:
        plt.savefig(savedir + filename, bbox_inches='tight', dpi=600)
    if show:
        fig.show()
    if returnimage:
        return image
def getFigAx(subplot, name=None, title=None, figsize=None, mpl=None,
             margins=None, sharex=None, sharey=None, AxesGrid=False,
             ngrids=None, direction='row', axes_pad=None, add_all=True,
             share_all=None, aspect=False, label_mode='L', cbar_mode=None,
             cbar_location='right', cbar_pad=None, cbar_size='5%',
             axes_class=None, lreduce=True):
    """Create a ``MyFigure`` and its axes (deprecated helper).

    Parameters
    ----------
    subplot : int or (int, int)
        Layout as ``(rows, columns)``, or one of the shorthand counts
        1, 2, 3, 4, 6, 9 (mapped to 1x1, 1x2, 1x3, 2x2, 2x3, 3x3).
    name : str, optional
        Window title.
    title : str, optional
        Figure suptitle; the top margin is reduced by 0.03 to make room.
    figsize, margins : tuple, optional
        Figure size in inches and ``(left, bottom, width, height)`` figure
        fractions; defaults depend on the layout.
    mpl : None, dict or module
        matplotlib module to use, or keyword arguments for ``loadMPL``.
    sharex, sharey : optional
        Axis sharing for the plain-subplot path.  None/True default to
        'col'/'row'; *share_all* forces 'all'.
    AxesGrid : bool
        Use ``mpl_toolkits.axes_grid1.ImageGrid`` instead of ``subplots``.
        The ImageGrid-specific arguments (``ngrids`` ... ``axes_class``)
        are forwarded to it.
    lreduce : bool
        Return a bare Axes when there is only one (``squeeze`` for subplots).

    Returns
    -------
    (figure, axes)

    Raises
    ------
    TypeError
        For an invalid *mpl* or *subplot* argument.
    NotImplementedError
        For an integer *subplot* without a shorthand layout.
    """
    warn('Deprecated function: use Figure or Axes class methods.')
    # configure matplotlib
    if mpl is None:
        import matplotlib as mpl
    elif isinstance(mpl, dict):
        mpl = loadMPL(**mpl)  # there can be a mplrc, but also others
    elif not isinstance(mpl, ModuleType):
        raise TypeError('mpl must be None, a dict of loadMPL arguments '
                        'or a module')
    from plotting.figure import MyFigure  # prevent circular reference

    # figure out subplots
    shorthand = {1: (1, 1), 2: (1, 2), 3: (1, 3), 4: (2, 2), 6: (2, 3),
                 9: (3, 3)}
    if isinstance(subplot, (np.integer, int)):
        if subplot not in shorthand:
            raise NotImplementedError(
                'no default layout for {} subplots'.format(subplot))
        subplot = shorthand[subplot]
    elif (isinstance(subplot, (tuple, list)) and len(subplot) == 2
          and all(isinstance(n, (int, np.integer)) for n in subplot)):
        # Normalise to a tuple so the layout comparisons below match lists.
        subplot = tuple(subplot)
    else:
        raise TypeError('subplot must be an int or a (rows, columns) pair '
                        'of ints')

    # create figure
    if figsize is None:
        if subplot == (1, 1):
            figsize = (3.75, 3.75)
        elif subplot in ((1, 2), (1, 3)):
            figsize = (6.25, 3.75)
        elif subplot in ((2, 1), (3, 1)):
            figsize = (3.75, 6.25)
        else:
            figsize = (6.25, 6.25)

    # figure out margins
    if margins is None:
        # N.B.: the rectangle definition is presumably left, bottom, width, height
        if subplot == (1, 1):
            margins = (0.09, 0.09, 0.88, 0.88)
        elif subplot in ((1, 2), (1, 3)):
            margins = (0.06, 0.1, 0.92, 0.87)
        elif subplot in ((2, 1), (3, 1)):
            margins = (0.09, 0.11, 0.88, 0.82)
        elif subplot in ((2, 2), (3, 3)):
            margins = (0.055, 0.055, 0.925, 0.925)
        else:
            margins = (0.09, 0.11, 0.88, 0.82)
    if title is not None:
        # make room for title; tuple() also accepts list margins
        margins = tuple(margins[:3]) + (margins[3] - 0.03,)

    if AxesGrid:
        if share_all is None:
            share_all = True
        if axes_pad is None:
            axes_pad = 0.05
        # create axes using the Axes Grid package
        fig = mpl.pylab.figure(facecolor='white', figsize=figsize,
                               FigureClass=MyFigure)
        if axes_class is None:
            from plotting.axes import MyLocatableAxes
            axes_class = (MyLocatableAxes, {})
        from mpl_toolkits.axes_grid1 import ImageGrid
        grid_kwargs = dict(nrows_ncols=subplot, ngrids=ngrids,
                           direction=direction, axes_pad=axes_pad,
                           share_all=share_all, aspect=aspect,
                           label_mode=label_mode, cbar_mode=cbar_mode,
                           cbar_location=cbar_location, cbar_pad=cbar_pad,
                           cbar_size=cbar_size, axes_class=axes_class)
        if not add_all:
            # 'add_all' was deprecated and later removed from ImageGrid
            # (adding all axes is the default), so forward it only when it
            # deviates from that default.
            grid_kwargs['add_all'] = add_all
        grid = ImageGrid(fig, margins, **grid_kwargs)
        axes = tuple(ax for ax in grid)  # this is already flattened
        if lreduce and len(axes) == 1:
            axes = axes[0]  # return a bare axes instance if there is only one
    else:
        # create axes using normal subplot routine
        if axes_pad is None:
            axes_pad = 0.03
        wspace = hspace = axes_pad
        if share_all:
            sharex = 'all'
            sharey = 'all'
        if sharex is True or sharex is None:
            sharex = 'col'  # default
        if sharey is True or sharey is None:
            sharey = 'row'
        if sharex:
            hspace -= 0.015
        if sharey:
            wspace -= 0.015
        from matplotlib.pyplot import subplots
        fig, axes = subplots(subplot[0], subplot[1], sharex=sharex,
                             sharey=sharey, squeeze=lreduce,
                             facecolor='white', figsize=figsize,
                             FigureClass=MyFigure)
        fig.subplots_adjust(left=margins[0], bottom=margins[1],
                            right=margins[0] + margins[2],
                            top=margins[1] + margins[3],
                            wspace=wspace, hspace=hspace)

    # add figure title
    if name is not None:
        # Newer matplotlib moved set_window_title from the canvas to the
        # figure manager.
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None and hasattr(manager, 'set_window_title'):
            manager.set_window_title(name)
        elif hasattr(fig.canvas, 'set_window_title'):
            fig.canvas.set_window_title(name)
    if title is not None:
        fig.suptitle(title)  # title on figure (printable)
    return fig, axes
def tiled_axis(ts, filename=None):
    """Configure one axis per day of ``ts`` and print the figure after each.

    Parameters
    ----------
    ts : object
        Time series exposing ``.time`` (passed to ``glucose.get_days``) and
        ``.value``. ``glucose.get_days(ts.time)`` must return an indexable
        sequence of datetimes (``len`` and ``isoformat`` are used).
    filename : ignored
        Accepted for backward compatibility; output file names are built
        from each day (``'<isoformat>-<index>.png'``).

    Returns
    -------
    The ``FigureCanvas`` wrapping the figure.
    """
    # XXX: gets a list of days.
    timestamps = glucose.get_days(ts.time)

    fig = Figure((2.56 * 4, 2.56 * 4), 300)
    canvas = FigureCanvas(fig)
    # One row per day, never fewer than the original 3 rows: with a fixed
    # 3-row grid, grid[i] raised IndexError as soon as ts spanned > 3 days.
    grid = ImageGrid(
        fig, 111,  # similar to subplot(111)
        nrows_ncols=(max(3, len(timestamps)), 1),
        axes_pad=0.5,
        add_all=True,
        label_mode="L",
    )
    # Right-hand x-limit padding of 2.5 days (days=2, hours=12).
    # NOTE(review): an older comment said "half a day" -- confirm intent.
    delta = dates.relativedelta(days=2, hours=12)
    fig.autofmt_xdate()

    # NOTE(review): make_plot, draw_glucose and draw_title are defined but
    # never called, so no glucose data is drawn on the axes yet.
    def make_plot(ax, limit):
        preferspan = ax.axhspan(SAFE[0], SAFE[1], facecolor='g', alpha=0.2,
                                edgecolor='#003333', linewidth=1)

    def draw_glucose(ax, limit):
        xmin, xmax = limit
        # visualize glucose using stems
        ax.set_xlim([xmin, xmax])
        markers, stems, baselines = ax.stem(ts.time, ts.value, linefmt='b:')
        plt.setp(markers, color='red', linewidth=.5, marker='o')
        plt.setp(baselines, marker='None')

    def draw_title(ax, limit):
        ax.set_title('glucose history')

    def get_axis(ax, limit):
        """Set x-limits, grid, 6-hourly minor date ticks; log the result."""
        xmin, xmax = limit
        ax.set_xlim([xmin, xmax])
        ax.grid(True)
        minorLocator = dates.HourLocator(interval=6)
        minorFormatter = dates.AutoDateFormatter(minorLocator)
        ax.xaxis.set_minor_locator(minorLocator)
        ax.xaxis.set_minor_formatter(minorFormatter)
        plt.setp(ax.get_xminorticklabels(), rotation=30, fontsize='small')
        plt.setp(ax.get_xmajorticklabels(), rotation=30, fontsize='medium')
        xmin, xmax = ax.get_xlim()
        log.info(pformat({
            'xlim': [dates.num2date(xmin), dates.num2date(xmax)],
            'xticks': dates.num2date(ax.get_xticks()),
        }))

    for i, day in enumerate(timestamps):
        get_axis(grid[i], [day, day + delta])
        name = '%s-%d.png' % (day.isoformat(), i)
        # The whole figure is printed each iteration, so file i contains
        # every axis configured so far.
        canvas.print_figure(name)
    #ax.set_ylabel('glucose mm/dL')
    return canvas
import matplotlib
# Render mathtext with matplotlib's Computer Modern ('cm') font set.
matplotlib.rcParams['mathtext.fontset'] = 'cm'
# NOTE(review): per the matplotlib rcParams docs, mathtext.rm is only
# consulted when mathtext.fontset is 'custom', so this line likely has no
# effect together with 'cm' -- confirm.
matplotlib.rcParams['mathtext.rm'] = 'serif'

# One figure per snapshot index n_snap = 0..55.
# NOTE(review): fig_width, n_cols, n_rows, data_all, plt and ImageGrid are
# not defined in this fragment and must exist before it runs. No
# plt.close(fig) is visible here; if none happens later in the loop, all 56
# figures stay open.
# n_snap = 0
for n_snap in range(0, 56):
    # Set up figure and image grid: n_rows x n_cols panels sharing one
    # colorbar on the right.
    fig = plt.figure(figsize=(fig_width * n_cols, 10 * n_rows),
                     )
    grid = ImageGrid(
        fig, 111,  # as in plt.subplot(111)
        nrows_ncols=(n_rows, n_cols),
        axes_pad=0.2,
        share_all=True,
        cbar_location="right",
        cbar_mode="single",
        cbar_size="5%",
        cbar_pad=0.1,
    )
    # Plot styling and fixed axis limits used for every snapshot.
    colormap = 'turbo'
    alpha = 0.6
    x_min, x_max = -3, 5.5
    y_min, y_max = 1.8, 8.
    ax = grid[0]
    # presumably a dict per snapshot; only the 'dens_points' key is read here
    dens_points = data_all[n_snap]['dens_points']
y_test = np_utils.to_categorical(y_test, num_classes)

# matplotlib inline
#
# Integer class ids recovered from the one-hot training labels.
cat_y = np.argmax(y_train, 1)

# Samples per class; bins are centred on the ten class ids 0..9.
f, ax = plt.subplots()
n, bins, patches = ax.hist(cat_y, bins=range(11), align='left', rwidth=0.5)
ax.set_xticks(bins[:-1])
ax.set_xlabel('Class Id')
ax.set_ylabel('# Samples')
ax.set_title('CIFAR-10 Class Distribution')

# 10 x 10 mosaic: row i shows (up to) the first ten training images of class i.
fig = plt.figure(figsize=(10., 10.))
grid = ImageGrid(
    fig, 111,  # similar to subplot(111)
    nrows_ncols=(10, 10),  # creates 10x10 grid of axes
    axes_pad=(0.03, 0.1),  # pad between axes in inch
)
for i in range(10):
    cat_im = X_train[cat_y == i]
    for j in range(10):
        ax = grid[10 * i + j]
        # A class with fewer than ten images leaves the remaining panels empty
        # instead of raising IndexError.
        if j < len(cat_im):
            # NOTE(review): cmap is ignored by imshow for RGB (H, W, 3) input.
            ax.imshow(cat_im[j].squeeze(), cmap='gray')
        ax.axis('off')
        ax.grid(False)
from orimat import qvec
from settings import e0, e1, e2, hist_bincenters, hist_binedges, a0_reciprocal

do_plotsurface = False  # set to True to plot surface orientation
# Run numbers to plot, one panel each (left to right).
runs = [4, 3, 8, 9, 7]
qarray = hist_bincenters[0] * a0_reciprocal
# Image extent: x from the edges of histogram axis 1, y from axis 0, both
# scaled by a0_reciprocal.
xmin, xmax = hist_binedges[1][0] * a0_reciprocal, hist_binedges[1][
    -1] * a0_reciprocal
ymin, ymax = hist_binedges[0][0] * a0_reciprocal, hist_binedges[0][
    -1] * a0_reciprocal

## plot the 2d images
# 1 x 5 grid of panels sharing one colorbar on the right.
fig = plt.figure(figsize=(6, 8))
grid = ImageGrid(fig, 111, nrows_ncols=(1,5), axes_pad=0.15, \
                 share_all=True, cbar_location='right', \
                 cbar_mode='single', cbar_size="15%", cbar_pad=0.15)
for i_run, run in enumerate(runs):
    ax = grid[i_run]
    # Per-run file names / metadata looked up via the project helper module.
    filename = helper.getparam("filenames", run)
    sample = helper.getparam("samples", runs[i_run])
    orimat = np.load("orimats/%s.npy" % (helper.getparam("orimats", run)))
    data_orig = np.load("data_hklmat/%s_interpolated_xy.npy" % filename)
    # log10 intensity map; colour scale capped at 3.0 (i.e. 10**3).
    # NOTE(review): zero-valued pixels give -inf (and a runtime warning).
    im = ax.imshow(np.log10(data_orig), extent=[xmin, xmax, ymin, ymax], \
                   origin='lower', interpolation='nearest', \
                   vmax=3.0)
    if do_plotsurface:
        # Unit vector of orimat applied to qvec(50., 25., 90., 0.).
        surf_norm = orimat.dot(qvec(50., 25., 90., 0.))
        surf_norm = surf_norm / np.linalg.norm(surf_norm)
        # NOTE(review): Python 2 print statement; a SyntaxError on Python 3.
        print "hkl of surface =", surf_norm
fig = plt.figure(figsize=(6, 6))

# Sample image, split into three row-interleaved sub-images (every 3rd row).
Z = cbook.get_sample_data("axes_grid/bivariate_normal.npy", np_load=True)
ZS = [Z[offset::3, :] for offset in range(3)]
left, right, bottom, top = (-3, 4, -4, 3)
extent = (left, right / 3., bottom, top)

# Demo 1: a separate colorbar above each of the three panels, placed in the
# top half of the figure (211, like subplot(211)).
grid_options = dict(
    nrows_ncols=(1, 3),
    axes_pad=0.05,
    label_mode="1",
    share_all=True,
    cbar_location="top",
    cbar_mode="each",
    cbar_size="7%",
    cbar_pad="1%",
)
grid = ImageGrid(fig, 211, **grid_options)

panel_titles = ["Image 1", "Image 2", "Image 3"]
for i, (ax, z) in enumerate(zip(grid, ZS)):
    im = ax.imshow(z, origin="lower", extent=extent)
    cb = ax.cax.colorbar(im)
    # All colorbars except the first get explicit ticks at -1, 0 and 1.
    if i > 0:
        cb.set_ticks([-1, 0, 1])

for ax, im_title in zip(grid, panel_titles):
    t = add_inner_title(ax, im_title, loc='lower left')
def create_channelmap(raster=None, contour=None, clevels=None, zeropoint=0.,
                      channels=None, ncols=4, vrange=None, vscale=None,
                      show=True, pdfname=None, cbarlabel=None,
                      cmap='RdYlBu_r', beam=True, **kwargs):
    """ Running this function will create a quick channel map of the Qube.
    One can either plot the contours or the raster image or both. This
    program can be used as a basis for a more detailed individual plot which
    can take better care of whitespace, etc.

    The following keywords are valid:
    raster:     qube object used to generate the raster image, leave blank or
                'None' if not desired.
    contour:    qube object used to generate the contours, leave blank or
                'None' if not desired.
    clevels:    nested lists of contour levels to draw, list should be the
                same length as the spectral dimension of the contour qube,
                if a single list is given assumes that the contours are the
                same for all channels.
    zeropoint:  Optional shift in velocities compared to the Restfrq keyword
                in the header of the Qube.
    channels:   indices of the spectral channels to plot (default: all).
    ncols:      number of columns to plot
    vrange:     range of the z-axis (for consistency across panels) if None,
                will take minimum maximum of the Qube
    vscale:     division of the colorbar (if none will make ~5 divisions);
                colorbar ticks are placed at integer multiples of vscale
                covering vrange.
    show:       if True (and pdfname is None) call plt.show() at the end.
    pdfname:    if set, the name of the pdf file to which the image is saved
                (the figure is then not shown).
    cbarlabel:  if set, label of the color bar
    cmap:       colormap to use for the plotting
    beam:       if True, standardfig receives beam=True for the bottom-left
                panel only (beam=False for all others).
    **kwargs:   passed on to standardfig for every panel.

    Raises:
    ValueError if neither raster nor contour is given, or if contour is
    given without clevels.

    Returns:
    fig, grid, channels -- the figure, the ImageGrid of panels and the
    channel indices that were plotted.
    """
    # generate a temporary qube from the data
    if raster is not None:
        tqube = raster
    elif contour is not None:
        tqube = contour
    else:
        raise ValueError('Need to define either a contour or raster image.')

    # first generate the velocity array
    VelArr = tqube._getvelocity_() - zeropoint

    # define the number of channels, rows (and columns = already defined)
    if channels is None:
        channels = np.arange(tqube.shape[0])
    # Integer ceiling division: np.ceil(len / ncols) silently truncated under
    # Python 2 integer division, leaving too few rows for the panels.
    nrows = int(-(-len(channels) // ncols))

    # define the range of the z-axis (if needed)
    if vrange is None:
        vrange = [np.nanmin(tqube.data), np.nanmax(tqube.data)]

    # now generate a grid with the specified number of columns
    # NOTE(review): figure number 1 is reused; existing content is not cleared.
    fig = plt.figure(1, (8., 8. / ncols * nrows))
    grid = ImageGrid(fig, 111, nrows_ncols=(nrows, ncols), axes_pad=0.,
                     cbar_mode='single', cbar_location='right',
                     share_all=True)

    # now loop over the channels
    for idx, chan in enumerate(channels):
        # get the string value of the velocity
        VelStr = str(int(round(VelArr[chan]))) + ' km s$^{-1}$'

        # the beam is drawn in the bottom-left panel only
        beambool = bool(beam and idx % ncols == 0 and
                        idx // ncols == nrows - 1)

        # plot the individual channels
        if raster is not None:
            rasterimage = raster.get_slice(zindex=(chan, chan + 1))
        else:
            rasterimage = None

        if contour is not None:
            contourimage = contour.get_slice(zindex=(chan, chan + 1))
            # also get the contour levels
            if clevels is None:
                raise ValueError('Set the contour levels using the clevels '
                                 'keyword')
            elif isinstance(clevels[0], (list, np.ndarray)):
                clevel = clevels[chan]  # one level list per channel
            else:
                clevel = clevels  # same levels for every channel
        else:
            contourimage = None
            clevel = clevels

        standardfig(raster=rasterimage, contour=contourimage, clevels=clevel,
                    newplot=False, fig=fig, ax=grid[idx], cbar=False,
                    beam=beambool, vrange=vrange, text=[VelStr], cmap=cmap,
                    **kwargs)

    # now do the color bar
    norm = mpl.colors.Normalize(vmin=vrange[0], vmax=vrange[1])
    cmapo = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    cmapo.set_array([])
    if vscale is None:
        vscale = (vrange[1] - vrange[0]) / 5.
    # Ticks at multiples of vscale spanning vrange. The former fixed
    # np.arange(-10, 10) * vscale gave no visible ticks whenever vrange lay
    # further than ~10 vscale from zero. Degenerate inputs keep the old ticks.
    vlo, vhi = min(vrange), max(vrange)
    if vscale > 0 and np.isfinite(vlo) and np.isfinite(vhi):
        ticks = np.arange(np.floor(vlo / vscale),
                          np.ceil(vhi / vscale) + 1) * vscale
    else:
        ticks = np.arange(-10, 10) * vscale
    cbr = plt.colorbar(cmapo, cax=grid.cbar_axes[0], ticks=ticks)

    # label of color bar
    if cbarlabel is not None:
        cbr.ax.set_ylabel(cbarlabel, labelpad=-1)

    if pdfname is not None:
        plt.savefig(pdfname, format='pdf', dpi=300)
    elif show:
        plt.show()

    return fig, grid, channels
def ransac_im_fit(im, mode=1, residual_threshold=0.1, min_samples=10,
                  max_trials=1000, model_f=None, p0=None, mask=None,
                  scale=False, fract=1, param_dict=None, plot=False,
                  axes=(-2, -1), widget=None):
    '''
    Fits a plane, polynomial, convex paraboloid, arbitrary function, or
    smoothing spline to an image using the RANSAC algorithm.

    Parameters
    ----------
    im : ndarray
        ndarray with images to fit to.
    mode : integer [0-4]
        Specifies model used for fit.
        0 is function defined by `model_f`.
        1 is plane.
        2 is quadratic.
        3 is concave paraboloid with offset.
        4 is smoothing spline.
        Any other value raises ValueError.
    model_f : callable or None
        Function to be fitted.
        Definition is model_f(p, *args), where p is 1-D iterable of
        params and args is iterable of (x, y, z) arrays of point cloud
        coordinates. See examples.
    p0 : tuple
        Initial guess of fit params for `model_f`.
    mask : 2-D boolean array
        Array with which to mask data. True values are ignored.
    scale : bool
        If True, `residual_threshold` is scaled, separately for each image,
        by the stdev of that image's (masked, sub-sampled) data.
    fract : scalar (0, 1]
        Fraction of data used for fitting, chosen randomly. Non-used data
        locations are set as nans in `inliers`.
    residual_threshold : float
        Maximum distance for a data point to be classified as an inlier.
    min_samples : int or float
        The minimum number of data points to fit a model to.
        If an int (including numpy integers), the value is the number of
        pixels.
        If a float, the value is a fraction (0.0, 1.0] of the total number
        of pixels.
    max_trials : int, optional
        Maximum number of iterations for random sample selection.
    param_dict : None or dictionary.
        If not None, the dictionary is passed to the model estimator.
        For arbitrary functions, this is passed to scipy.optimize.leastsq.
        For spline fitting, this is passed to
        scipy.interpolate.SmoothBivariateSpline. All other models take no
        parameters.
    plot : bool
        If True, the data, including inliers, model, etc are plotted.
    axes : length 2 iterable
        Indices of the input array with images.
    widget : QtWidget
        A qt widget that must contain as many widgets or dock widget as
        there is popup window

    Returns
    -------
    Tuple of fit, inliers, n, where:
    fit : 2-D array
        Image of fitted model.
    inliers : 2-D array
        Boolean array describing inliers. When `mask` is given or
        `fract` < 1 it is a float array with nans at unused locations.
    n : array or None
        Normal of plane fit. `None` for other models.

    Notes
    -----
    See skimage.measure.ransac for details of RANSAC algorithm.

    `min_samples` should be chosen appropriate to the size of the image
    and to the variation in the image.

    Increasing `residual_threshold` increases the fraction of the image
    fitted to.

    The entire image can be fitted to without RANSAC by setting:
    max_trials=1, min_samples=1.0, residual_threshold=`x`, where `x` is a
    suitably large value.

    Examples
    --------
    `model_f` for paraboloid with offset:

    >>> def model_f(p, *args):
    ...     (x, y, z) = args
    ...     m = np.abs(p[0])*((x-p[1])**2 + (y-p[2])**2) + p[3]
    ...     return m
    >>> p0 = (0.1, 10, 20, 0)


    To plot fit, inliers etc:

    >>> from fpd.ransac_tools import ransac_im_fit
    >>> import matplotlib as mpl
    >>> from numpy.ma import masked_where
    >>> import numpy as np
    >>> import matplotlib.pylab as plt
    >>> plt.ion()

    >>> cmap = mpl.cm.gray
    >>> cmap.set_bad('r')

    >>> image = np.random.rand(*(64,)*2)
    >>> fit, inliers, n = ransac_im_fit(image, mode=1)
    >>> cor_im = image-fit
    >>> pct = 0.5
    >>> vmin, vmax = np.percentile(cor_im, [pct, 100-pct])
    >>>
    >>> f, axs = plt.subplots(1, 4, sharex=True, sharey=True)
    >>> _ = axs[0].matshow(image, cmap=cmap)
    >>> _ = axs[1].matshow(masked_where(inliers==False, image), cmap=cmap)
    >>> _ = axs[2].matshow(fit, cmap=cmap)
    >>> _ = axs[3].matshow(cor_im, vmin=vmin, vmax=vmax)


    To plot plane normal vs threshold:

    >>> from fpd.ransac_tools import ransac_im_fit
    >>> from numpy.ma import masked_where
    >>> import numpy as np
    >>> from tqdm import tqdm
    >>> import matplotlib.pylab as plt
    >>> plt.ion()

    >>> image = np.random.rand(*(64,)*2)
    >>> ns = []
    >>> rts = np.logspace(0, 1.5, 5)
    >>> for rt in tqdm(rts):
    ...     nis = []
    ...     for i in range(64):
    ...         fit, inliers, n = ransac_im_fit(image, residual_threshold=rt, max_trials=10)
    ...         nis.append(n)
    ...     ns.append(np.array(nis).mean(0))
    >>> ns = np.array(ns)

    >>> thx = np.arctan2(ns[:,1], ns[:,2])
    >>> thy = np.arctan2(ns[:,0], ns[:,2])
    >>> thx = np.rad2deg(thx)
    >>> thy = np.rad2deg(thy)

    >>> _ = plt.figure()
    >>> _ = plt.semilogx(rts, thx)
    >>> _ = plt.semilogx(rts, thy)

    '''
    import numbers

    # Set model
    # Functions defining classes are needed to pass parameters since class must
    # not be instantiated or are monkey patched (only in spline implementation)
    if mode == 0:
        # generate model_class with passed function
        if p0 is None:
            raise NotImplementedError('p0 must be specified.')
        model_class = rt._model_class_gen(model_f, p0, param_dict)
    elif mode == 1:
        # linear
        model_class = rt._Plane3dModel
    elif mode == 2:
        # quadratic
        model_class = rt._Poly3dModel
    elif mode == 3:
        # concave paraboloid
        model_class = rt._Poly3dParaboloidModel
    elif mode == 4:
        # spline: a fresh subclass per call, so param_dict is not shared
        class _Spline3dModel_monkeypatched(_Spline3dModel):
            pass
        model_class = _Spline3dModel_monkeypatched
        model_class.param_dict = param_dict
    else:
        # previously fell through and failed later with a NameError
        raise ValueError('mode must be 0, 1, 2, 3 or 4, got %r.' % (mode,))

    multiim = False
    if im.ndim > 2:
        multiim = True
        axes = [int(el) for el in axes]
        ims, unflat_shape = seq_image_array(im, axes)
        pbar = tqdm(total=ims.shape[0])
    else:
        ims = im[None]

    # Resolve min_samples to a pixel count once. numbers.Integral also covers
    # numpy integers, which `type(min_samples) is int` treated as fractions.
    if isinstance(min_samples, numbers.Integral):
        min_samples = int(min_samples)
    else:
        # take number as fraction of the total number of pixels per image
        if mask is None:
            n_pix = int(np.prod(ims.shape[1:]))
        else:
            n_pix = np.asarray(mask).size
        min_samples = int(n_pix * min_samples)
        print("min_samples is set to: %d" % (min_samples))

    fits = []
    inlierss = []
    ns = []
    for imi in ims:
        # set data
        yy, xx = np.indices(imi.shape)
        zz = imi
        if mask is None:
            keep = (np.ones_like(imi) == 1).flatten()
        else:
            keep = (mask == False).flatten()
        data = np.column_stack([xx.flat[keep], yy.flat[keep], zz.flat[keep]])

        # randomly select data
        sel = np.random.rand(data.shape[0]) <= fract
        data = data[sel.flatten()]

        # scale residual per image, if chosen; the caller's threshold is kept
        # unchanged so it is not rescaled cumulatively across images
        if scale:
            threshold = residual_threshold * np.std(data[:, 2])
        else:
            threshold = residual_threshold

        # determine if fitting to all
        full_fit = min_samples == data.shape[0]
        if not full_fit:
            # do ransac fit
            model, inliers = ransac(data=data,
                                    model_class=model_class,
                                    min_samples=min_samples,
                                    residual_threshold=threshold,
                                    max_trials=max_trials)
        else:
            model = model_class()
            inliers = np.ones(data.shape[0]) == 1

        # get params from fit with all inliers
        model.estimate(data[inliers])

        # get model over all x, y
        args = (xx.flatten(), yy.flatten(), zz.flatten())
        fit = model.my_model(model.params, *args).reshape(imi.shape)

        if mask is None and fract == 1:
            inliers = inliers.reshape(imi.shape)
        else:
            # Unused pixels are NaN, which needs a floating dtype: the old
            # np.empty_like(imi) raised ValueError for integer images.
            if np.issubdtype(imi.dtype, np.floating):
                nan_dtype = imi.dtype
            else:
                nan_dtype = float
            inliers_nans = np.full(imi.size, np.nan, dtype=nan_dtype)
            yi = np.arange(imi.size)
            sel_fit = yi[keep][sel.flatten()]
            inliers_nans[sel_fit] = inliers
            inliers = inliers_nans.reshape(imi.shape)

        # calculate normal for plane
        if mode == 1:
            # linear
            C = model.params
            n = np.array([-C[0], -C[1], 1])
            n_mag = np.linalg.norm(n, ord=None, axis=0)
            n = n / n_mag
        else:
            # non-linear
            n = None

        if plot:
            import copy
            import matplotlib.pylab as plt
            import matplotlib as mpl
            from numpy.ma import masked_where
            from mpl_toolkits.axes_grid1 import ImageGrid
            plt.ion()

            # copy so set_bad does not modify the globally registered colormap
            cmap = copy.copy(mpl.cm.gray)
            cmap.set_bad('r')

            cor_im = imi - fit
            pct = 0.1
            vmin, vmax = np.percentile(cor_im, [pct, 100 - pct])

            if widget is None:
                fig = plt.figure()
            else:
                docked = widget.setup_docking("Circular Center", "Bottom")
                fig = docked.get_fig()
                fig.clf()
            grid = ImageGrid(fig, 111,
                             nrows_ncols=(1, 4),
                             axes_pad=0.1,
                             share_all=True,
                             label_mode="L",
                             cbar_location="right",
                             cbar_mode="single")

            images = [imi, masked_where(inliers == False, imi), fit, cor_im]
            titles = ['Image', 'Inliers', 'Fit', 'Corrected']
            for i, image in enumerate(images):
                img = grid[i].imshow(image, cmap=cmap, interpolation='nearest')
                grid[i].set_title(titles[i])
                img.set_clim(vmin, vmax)
            grid.cbar_axes[0].colorbar(img)

        fits.append(fit)
        inlierss.append(inliers)
        ns.append(n)
        if multiim:
            pbar.update(1)
    fit = np.array(fits)
    inliers = np.array(inlierss)
    n = np.array(ns)

    if multiim:
        pbar.close()
        # reshape
        fit = unseq_image_array(fit, axes, unflat_shape)
        inliers = unseq_image_array(inliers, axes, unflat_shape)
        n = unseq_image_array(n, axes, unflat_shape)
    else:
        fit = fit[0]
        inliers = inliers[0]
        n = n[0]

    return (fit, inliers, n)
def modeldata_comparison(cube, pdffile=None, **kwargs):
    """Make a six-panel overview comparing the model with the data.

    Top row: moment-zero maps of the data, the model and the residual
    (data - model), each drawn as raster plus contours. Bottom row:
    moment-one maps of the data and the model, and the difference of those
    two moment-one maps.

    inputs:
    -------
    cube:   Qubefit object that holds the data (cube.data), the model
            (cube.model) and the variance (cube.variance).

    keywords:
    ---------
    pdffile (string|default: None):
        If set, the figure is saved to this file in pdf format.
    **kwargs:
        Passed on to every standardfig call. The documented `showmask`
        (Bool|default: False) keyword -- show a contour of the mask used in
        creating the fit (stored in cube.maskarray) -- is not handled here;
        presumably standardfig handles it.

    returns:
    --------
    fig, grid: the matplotlib figure and the 2x3 ImageGrid of panels.
    """
    # create the figure plots
    # NOTE(review): figure number 1 is reused; existing content is not cleared.
    fig = plt.figure(1, (8., 8.))
    grid = ImageGrid(fig, 111, nrows_ncols=(2, 3), axes_pad=0)

    # create the data, model and residual cubes
    dqube = dc(cube)
    mqube = dc(cube)
    mqube.data = cube.model
    rqube = dc(cube)
    rqube.data = cube.data - cube.model

    # moment-zero data
    dMom0 = dqube.calculate_moment(moment=0)
    Mom0sig = (np.sqrt(np.nansum(cube.variance[:, 0, 0])) *
               cube.__get_velocitywidth__())
    # contours at -30..-3 and +3..+27 in steps of 3 * Mom0sig
    clevels = np.insert(np.arange(3, 30, 3), 0, np.arange(-30, 0, 3)) * Mom0sig
    vrangemom = [-3 * Mom0sig, 11 * Mom0sig]
    # mask built from the data moment-zero map at 3 * Mom0sig; applied to the
    # cubes before computing the moment-one maps below
    mask = dMom0.mask_region(value=Mom0sig * 3, applymask=False)
    standardfig(raster=dMom0, contour=dMom0, clevels=clevels, ax=grid[0],
                fig=fig, vrange=vrangemom, cmap='RdYlBu_r', text='Data',
                textprop=[dict(size=12)], **kwargs)

    # moment-zero model
    mMom0 = mqube.calculate_moment(moment=0)
    standardfig(raster=mMom0, contour=mMom0, clevels=clevels, ax=grid[1],
                fig=fig, vrange=vrangemom, cmap='RdYlBu_r', text='Model',
                textprop=[dict(size=12)], **kwargs)

    # moment-zero residual
    rMom0 = rqube.calculate_moment(moment=0)
    standardfig(raster=rMom0, contour=rMom0, clevels=clevels, ax=grid[2],
                fig=fig, vrange=vrangemom, cmap='RdYlBu_r', text='Residual',
                textprop=[dict(size=12)], **kwargs)

    # moment-one data
    dsqube = dqube.mask_region(mask=mask)
    dMom1 = dsqube.calculate_moment(moment=1)
    vrangemom1 = [0.95 * np.nanmin(dMom1.data), 0.95 * np.nanmax(dMom1.data)]
    standardfig(raster=dMom1, ax=grid[3], fig=fig, cmap='Spectral_r',
                vrange=vrangemom1, text='Data', textprop=[dict(size=12)],
                **kwargs)

    # moment-one model
    msqube = mqube.mask_region(mask=mask)
    mMom1 = msqube.calculate_moment(moment=1)
    standardfig(raster=mMom1, ax=grid[4], fig=fig, cmap='Spectral_r',
                vrange=vrangemom1, beam=False, text='Model',
                textprop=[dict(size=12)], **kwargs)

    # moment-one residual (difference of the two moment-one maps)
    rMom1 = dc(mMom1)
    rMom1.data = dMom1.data - mMom1.data
    standardfig(raster=rMom1, ax=grid[5], fig=fig, cmap='Spectral_r',
                vrange=vrangemom1, beam=False, text='Residual',
                textprop=[dict(size=12)], **kwargs)

    # the pdffile keyword was documented but previously never used
    if pdffile is not None:
        fig.savefig(pdffile, format='pdf', dpi=300)

    return fig, grid
class Plot(object): def __init__(self, kind='', figsize=None, nrows=1, ncols=1, rect=111, cbar_mode='single', squeeze=False, **kwargs): self._create_subplots(kind=kind, figsize=figsize, nrows=nrows, ncols=ncols, **kwargs) def _create_subplots(self, kind='', figsize=None, nrows=1, ncols=1, rect=111, cbar_mode='single', squeeze=False, **kwargs): """ :Kwargs: - kind (str, default: '') The kind of plot. For plotting matrices or images (`matplotlib.pyplot.imshow`), choose `matrix`, otherwise leave blank. - figsize (tuple, defaut: None) Size of the figure. - nrows_ncols (tuple, default: (1, 1)) Shape of subplot arrangement. - **kwargs A dictionary of keyword arguments that `matplotlib.ImageGrid` or `matplotlib.pyplot.suplots` accept. Differences: - `rect` (`matplotlib.ImageGrid`) is a keyword argument here - `cbar_mode = 'single'` - `squeeze = False` :Returns: `matplotlib.pyplot.figure` and a grid of axes. """ if 'nrows_ncols' not in kwargs: nrows_ncols = (nrows, ncols) else: nrows_ncols = kwargs['nrows_ncols'] del kwargs['nrows_ncols'] try: num = self.fig.number self.fig.clf() except: num = None if kind == 'matrix': self.fig = self.figure(figsize=figsize, num=num) self.axes = ImageGrid(self.fig, rect, nrows_ncols=nrows_ncols, cbar_mode=cbar_mode, **kwargs ) else: self.fig, self.axes = plt.subplots( nrows=nrows_ncols[0], ncols=nrows_ncols[1], figsize=figsize, squeeze=squeeze, num=num, **kwargs ) self.axes = self.axes.ravel() # turn axes into a list self.kind = kind self.subplotno = -1 # will get +1 after the plot command self.nrows_ncols = nrows_ncols return (self.fig, self.axes) def __getattr__(self, name): """Pass on a `matplotlib` function that we haven't modified """ def method(*args, **kwargs): return getattr(plt, name)(*args, **kwargs) try: return method # is it a function? 
except TypeError: # so maybe it's just a self variable return getattr(self, name) def __getitem__(self, key): """Allow to get axes as Plot()[key] """ if key > len(self.axes): raise IndexError if key < 0: key += len(self.axes) return self.axes[key] def get_ax(self, subplotno=None): """ Returns the current or the requested axis from the current figure. :note: The :class:`Plot()` is indexable so you should access axes as `Plot()[key]` unless you want to pass a list like (row, col). :Kwargs: subplotno (int, default: None) Give subplot number explicitly if you want to get not the current axis :Returns: ax """ if subplotno is None: no = self.subplotno else: no = subplotno if isinstance(no, int): ax = self.axes[no] else: if no[0] < 0: no += len(self.axes._nrows) if no[1] < 0: no += len(self.axes._ncols) if isinstance(self.axes, ImageGrid): # axes are a list if self.axes._direction == 'row': no = self.axes._ncols * no[0] + no[1] else: no = self.axes._nrows * no[0] + no[1] else: # axes are a grid no = self.axes._ncols * no[0] + no[1] ax = self.axes[no] return ax def next(self): """ Returns the next axis. This is useful when a plotting function is not implemented by :mod:`plot` and you have to instead rely on matplotlib's plotting which does not advance axes automatically. """ self.subplotno += 1 return self.get_ax() def sample_paired(self, ncolors=2): """ Returns colors for matplotlib.cm.Paired. """ if ncolors <= 12: colors_full = [mpl.cm.Paired(i * 1. / 11) for i in range(1, 12, 2)] colors_pale = [mpl.cm.Paired(i * 1. / 11) for i in range(10, -1, -2)] colors = colors_full + colors_pale return colors[:ncolors] else: return [mpl.cm.Paired(c) for c in np.linspace(0,ncolors)] def get_colors(self, ncolors=2, cmap='Paired'): """ Get a list of nice colors for plots. FIX: This function is happy to ignore the ugly settings you may have in your matplotlibrc settings. TODO: merge with mpltools.color :Kwargs: ncolors (int, default: 2) Number of colors required. 
Typically it should be the number of entries in the legend. cmap (str or matplotlib.cm, default: 'Paired') A colormap to sample from when ncolors > 12 :Returns: a list of colors """ colorc = plt.rcParams['axes.color_cycle'] if ncolors < len(colorc): colors = colorc[:ncolors] elif ncolors <= 12: colors = self.sample_paired(ncolors=ncolors) else: thisCmap = mpl.cm.get_cmap(cmap) norm = mpl.colors.Normalize(0, 1) z = np.linspace(0, 1, ncolors + 2) z = z[1:-1] colors = thisCmap(norm(z)) return colors def pivot_plot(self,df,rows=None,cols=None,values=None,yerr=None, **kwargs): agg = self.aggregate(df, rows=rows, cols=cols, values=values, yerr=yerr) if yerr is None: no_yerr = True else: no_yerr = False return self._plot(agg, no_yerr=no_yerr,**kwargs) def _plot(self, agg, ax=None, title='', kind='bar', xtickson=True, ytickson=True, no_yerr=False, numb=False, autoscale=True, **kwargs): """DEPRECATED plotting function""" print "plot._plot() has been DEPRECATED; please don't use it anymore" self.plot(agg, ax=ax, title=title, kind=kind, xtickson=xtickson, ytickson=ytickson, no_yerr=no_yerr, numb=numb, autoscale=autoscale, **kwargs) def plot(self, agg, subplots=None, **kwargs): """ The main plotting function. :Args: agg (`pandas.DataFrame` or similar) A structured input, preferably a `pandas.DataFrame`, but in principle accepts anything that can be converted into it. :Kwargs: - subplots (None, True, or False; default=None) Whether you want to split data into subplots or not. If True, the top level is treated as a subplot. If None, detects automatically based on `agg.columns.names` -- the first entry to start with `subplots.` will be used. This is the default output from `stats.aggregate` and is recommended. - **kwargs Keyword arguments for plotting :Returns: A list of axes of all plots. 
""" agg = pandas.DataFrame(agg) axes = [] try: s_idx = [s for s,n in enumerate(agg.columns.names) if n.startswith('subplots.')] except: s_idx = None if s_idx is not None: # subplots implicit in agg if len(s_idx) != 0: sbp = agg.columns.levels[s_idx[0]] else: sbp = None elif subplots: # get subplots from the top level column sbp = agg.columns.levels[0] else: sbp = None if sbp is None: axes = [self._plot_ax(agg, **kwargs)] else: # if haven't made any plots yet... if self.subplotno == -1: num_subplots = len(sbp) # ...can still adjust the number of subplots if num_subplots > len(self.axes): self._create_subplots(ncols=num_subplots) for no, subname in enumerate(sbp): # all plots are the same, onle legend will suffice if subplots is None or subplots: if no == 0: legend = True else: legend = False else: # plots vary; each should get a legend legend = True ax = self._plot_ax(agg[subname], title=subname, legend=legend, **kwargs) if 'title' in kwargs: ax.set_title(kwargs['title']) else: ax.set_title(subname) axes.append(ax) return axes def _plot_ax(self, agg, ax=None, title='', kind='bar', legend=True, xtickson=True, ytickson=True, no_yerr=False, numb=False, autoscale=True, order=None, **kwargs): if ax is None: self.subplotno += 1 ax = self.get_ax() if isinstance(agg, pandas.DataFrame): mean, p_yerr = self.errorbars(agg) else: mean = agg p_yerr = np.zeros((len(agg), 1)) if mean.index.nlevels == 1: # oops, nothing to unstack mean = pandas.DataFrame(mean).T p_yerr = pandas.DataFrame(p_yerr).T else: # make columns which will turn into legend entries for name in agg.columns.names: if name.startswith('cols.'): mean = mean.unstack(level=name) p_yerr = p_yerr.unstack(level=name) if isinstance(agg, pandas.Series) and kind=='bean': kind = 'bar' print 'WARNING: Beanplot not available for a single measurement' if kind == 'bar': self.barplot(mean, yerr=p_yerr, ax=ax) elif kind == 'line': self.lineplot(mean, yerr=p_yerr, ax=ax) elif kind == 'bean': autoscale = False # FIX: autoscaling is 
incorrect on beanplots #if len(mean.columns) <= 2: ax = self.beanplot(agg, ax=ax, order=order, **kwargs)#, pos=range(len(mean.index))) #else: #raise Exception('Beanplot is not available for more than two ' #'classes.') else: raise Exception('%s plot not recognized. Choose from ' '{bar, line, bean}.' %kind) # TODO: xticklabel rotation business is too messy if 'xticklabels' in kwargs: ax.set_xticklabels(kwargs['xticklabels'], rotation=0) if not xtickson: ax.set_xticklabels(['']*len(ax.get_xticklabels())) labels = ax.get_xticklabels() max_len = max([len(label.get_text()) for label in labels]) for label in labels: if max_len > 20: label.set_rotation(90) else: label.set_rotation(0) #label.set_size('x-large') #ax.set_xticklabels(labels, rotation=0, size='x-large') if not ytickson: ax.set_yticklabels(['']*len(ax.get_yticklabels())) ax.set_xlabel('') # set y-axis limits if 'ylim' in kwargs: ax.set_ylim(kwargs['ylim']) elif autoscale: mean_array = np.asarray(mean) r = np.max(mean_array) - np.min(mean_array) ebars = np.where(np.isnan(p_yerr), r/3., p_yerr) if kind == 'bar': ymin = np.min(np.asarray(mean) - ebars) if ymin > 0: ymin = 0 else: ymin = np.min(np.asarray(mean) - 3*ebars) else: ymin = np.min(np.asarray(mean) - 3*ebars) if kind == 'bar': ymax = np.max(np.asarray(mean) + ebars) if ymax < 0: ymax = 0 else: ymax = np.max(np.asarray(mean) + 3*ebars) else: ymax = np.max(np.asarray(mean) + 3*ebars) ax.set_ylim([ymin, ymax]) # set x and y labels if 'xlabel' in kwargs: ax.set_xlabel(kwargs['xlabel']) else: ax.set_xlabel(self._get_title(mean, 'rows')) if 'ylabel' in kwargs: ax.set_ylabel(kwargs['ylabel']) else: ax.set_ylabel(self._get_title(mean, 'cols')) # set x tick labels #FIX: data.index returns float even if it is int because dtype=object #if len(mean.index) == 1: # no need to put a label for a single bar group #ax.set_xticklabels(['']) #else: ax.set_xticklabels(mean.index.tolist()) ax.set_title(title) self._draw_legend(ax, visible=legend, data=mean, **kwargs) if numb 
== True: self.add_inner_title(ax, title='%s' % self.subplotno, loc=2) return ax def _get_title(self, data, pref): if pref == 'cols': dnames = data.columns.names else: dnames = data.index.names title = [n.split('.',1)[1] for n in dnames if n.startswith(pref+'.')] title = ', '.join(title) return title def _draw_legend(self, ax, visible=True, data=None, **kwargs): l = ax.get_legend() # get an existing legend if l is None: # create a new legend l = ax.legend() l.legendPatch.set_alpha(0.5) l.set_title(self._get_title(data, 'cols')) if 'legend_visible' in kwargs: l.set_visible(kwargs['legend_visible']) elif visible is not None: l.set_visible(visible) else: #decide automatically if len(l.texts) == 1: # showing a single legend entry is useless l.set_visible(False) else: l.set_visible(True) def hide_plots(self, nums): """ Hides an axis. :Args: nums (int, tuple or list of ints) Which axes to hide. """ if isinstance(nums, int) or isinstance(nums, tuple): nums = [nums] for num in nums: ax = self.get_ax(num) ax.axis('off') def barplot(self, data, yerr=None, ax=None): """ Plots a bar plot. :Args: data (`pandas.DataFrame` or any other array accepted by it) A data frame where rows go to the x-axis and columns go to the legend. """ data = pandas.DataFrame(data) if yerr is None: yerr = np.empty(data.shape) yerr = yerr.reshape(data.shape) # force this shape yerr = np.nan if ax is None: self.subplotno += 1 ax = self.get_ax() colors = self.get_colors(len(data.columns)) n = len(data.columns) idx = np.arange(len(data)) width = .75 / n rects = [] for i, (label, column) in enumerate(data.iteritems()): rect = ax.bar(idx+i*width+width/2, column, width, label=str(label), yerr=yerr[label].tolist(), color = colors[i], ecolor='black') # TODO: yerr indexing might need fixing rects.append(rect) ax.set_xticks(idx + width*n/2 + width/2) ax.legend(rects, data.columns.tolist()) return ax def lineplot(self, data, yerr=None, ax=None): """ Plots a bar plot. 
:Args: data (`pandas.DataFrame` or any other array accepted by it) A data frame where rows go to the x-axis and columns go to the legend. """ data = pandas.DataFrame(data) if yerr is None: yerr = np.empty(data.shape) yerr = yerr.reshape(data.shape) # force this shape yerr = np.nan if ax is None: self.subplotno += 1 ax = self.get_ax() #colors = self.get_colors(len(data.columns)) x = range(len(data)) lines = [] for i, (label, column) in enumerate(data.iteritems()): line = ax.plot(x, column, label=str(label)) lines.append(line) ax.errorbar(x, column, yerr=yerr[label].tolist(), fmt=None, ecolor='black') #ticks = ax.get_xticks().astype(int) #if ticks[-1] >= len(data.index): #labels = data.index[ticks[:-1]] #else: #labels = data.index[ticks] #ax.set_xticklabels(labels) #ax.legend() #loc='center left', bbox_to_anchor=(1.3, 0.5) #loc='upper right', frameon=False return ax def scatter(self, x, y, ax=None, labels=None, title='', **kwargs): """ Draws a scatter plot. This is very similar to `matplotlib.pyplot.scatter` but additionally accepts labels (for labeling points on the plot), plot title, and an axis where the plot should be drawn. :Args: - x (an iterable object) An x-coordinate of data - y (an iterable object) A y-coordinate of data :Kwargs: - ax (default: None) An axis to plot in. - labels (list of str, default: None) A list of labels for each plotted point - title (str, default: '') Plot title - ** kwargs Additional keyword arguments for `matplotlib.pyplot.scatter` :Return: Current axis for further manipulation. """ if ax is None: self.subplotno += 1 ax = self.get_ax() plt.rcParams['axes.color_cycle'] ax.scatter(x, y, marker='o', color=self.get_colors()[0], **kwargs) if labels is not None: for c, (pointx, pointy) in enumerate(zip(x,y)): ax.text(pointx, pointy, labels[c]) ax.set_title(title) return ax def matrix_plot(self, matrix, ax=None, title='', **kwargs): """ Plots a matrix. .. 
warning:: Not tested yet :Args: matrix :Kwargs: - ax (default: None) An axis to plot on. - title (str, default: '') Plot title - **kwargs Keyword arguments to pass to `matplotlib.pyplot.imshow` """ if ax is None: ax = plt.subplot(111) import matplotlib.colors norm = matplotlib.colors.normalize(vmax=1, vmin=0) mean, sem = self.errorbars(matrix) #matrix = pandas.pivot_table(mean.reset_index(), rows=) im = ax.imshow(mean, norm=norm, interpolation='none', **kwargs) # ax.set_title(title) ax.cax.colorbar(im)#, ax=ax, use_gridspec=True) # ax.cax.toggle_label(True) t = self.add_inner_title(ax, title, loc=2) t.patch.set_ec("none") t.patch.set_alpha(0.8) xnames = ['|'.join(map(str,label)) for label in matrix.minor_axis] ax.set_xticks(range(len(xnames))) ax.set_xticklabels(xnames) # rotate long labels if max([len(n) for n in xnames]) > 20: ax.axis['bottom'].major_ticklabels.set_rotation(90) ynames = ['|'.join(map(str,label)) for label in matrix.major_axis] ax.set_yticks(range(len(ynames))) ax.set_yticklabels(ynames) return ax def add_inner_title(self, ax, title, loc=2, size=None, **kwargs): from matplotlib.offsetbox import AnchoredText from matplotlib.patheffects import withStroke if size is None: size = dict(size=plt.rcParams['legend.fontsize']) at = AnchoredText(title, loc=loc, prop=size, pad=0., borderpad=0.5, frameon=False, **kwargs) ax.add_artist(at) at.txt._text.set_path_effects([withStroke(foreground="w", linewidth=3)]) return at def errorbars(self, df, yerr_type='sem'): # Set up error bar information if yerr_type == 'sem': mean = df.mean() # mean across items # std already has ddof=1 sem = df.std() / np.sqrt(len(df)) #yerr = np.array(sem)#.reshape(mean.shape) # force this shape elif yerr_type == 'binomial': pass # alpha = .05 # z = stats.norm.ppf(1-alpha/2.) 
# count = np.mean(persubj, axis=1, ddof=1) # p_yerr = z*np.sqrt(mean*(1-mean)/persubj.shape[1]) return mean, sem def stats_test(self, agg, test='ttest'): d = agg.shape[0] if test == 'ttest': # 2-tail T-Test ttest = (np.zeros((agg.shape[1]*(agg.shape[1]-1)/2, agg.shape[2])), np.zeros((agg.shape[1]*(agg.shape[1]-1)/2, agg.shape[2]))) ii = 0 for c1 in range(agg.shape[1]): for c2 in range(c1+1,agg.shape[1]): thisTtest = stats.ttest_rel(agg[:,c1,:], agg[:,c2,:], axis = 0) ttest[0][ii,:] = thisTtest[0] ttest[1][ii,:] = thisTtest[1] ii += 1 ttestPrint(title = '**** 2-tail T-Test of related samples ****', values = ttest, plotOpt = plotOpt, type = 2) elif test == 'ttest_1samp': # One-sample t-test m = .5 oneSample = stats.ttest_1samp(agg, m, axis = 0) ttestPrint(title = '**** One-sample t-test: difference from %.2f ****' %m, values = oneSample, plotOpt = plotOpt, type = 1) elif test == 'binomial': # Binomial test binom = np.apply_along_axis(stats.binom_test,0,agg) print binom return binom def ttestPrint(self, title = '****', values = None, xticklabels = None, legend = None, bon = None): d = 8 # check if there are any negative t values (for formatting purposes) if np.any([np.any(val < 0) for val in values]): neg = True else: neg = False print '\n' + title for xi, xticklabel in enumerate(xticklabels): print xticklabel maxleg = max([len(leg) for leg in legend]) # if type == 1: legendnames = ['%*s' %(maxleg,p) for p in plotOpt['subplot']['legend.names']] # elif type == 2: pairs = q.combinations(legend,2) legendnames = ['%*s' %(maxleg,p[0]) + ' vs ' + '%*s' %(maxleg,p[1]) for p in pairs] #print legendnames for yi, legendname in enumerate(legendnames): if values[0].ndim == 1: t = values[0][xi] p = values[1][xi] else: t = values[0][yi,xi] p = values[1][yi,xi] if p < .001/bon: star = '***' elif p < .01/bon: star = '**' elif p < .05/bon: star = '*' else: star = '' if neg and t > 0: outputStr = ' %(s)s: t(%(d)d) = %(t).3f, p = %(p).3f %(star)s' else: outputStr = ' %(s)s: t(%(d)d) = 
%(t).3f, p = %(p).3f %(star)s' print outputStr \ %{'s': legendname, 'd':(d-1), 't': t, 'p': p, 'star': star} def mds(self, results, labels, fonts='freesansbold.ttf', title='', ax = None): """Plots Multidimensional scaling results""" if ax is None: try: row = self.subplotno / self.axes[0][0].numCols col = self.subplotno % self.axes[0][0].numCols ax = self.axes[row][col] except: ax = self.axes[self.subplotno] ax.set_title(title) # plot each point with a name dims = results.ndim try: if results.shape[1] == 1: dims = 1 except: pass if dims == 1: df = pandas.DataFrame(results, index=labels, columns=['data']) df = df.sort(columns='data') self._plot(df) elif dims == 2: for c, coord in enumerate(results): ax.plot(coord[0], coord[1], 'o', color=mpl.cm.Paired(.5)) ax.text(coord[0], coord[1], labels[c], fontproperties=fonts[c]) else: print 'Cannot plot more than 2 dims' def _violinplot(self, data, pos, rlabels, ax=None, bp=False, cut=None, **kwargs): """ Make a violin plot of each dataset in the `data` sequence. 
Based on `code by Teemu Ikonen <http://matplotlib.1069221.n5.nabble.com/Violin-and-bean-plots-tt27791.html>`_ which was based on `code by Flavio Codeco Coelho <http://pyinsci.blogspot.com/2009/09/violin-plot-with-matplotlib.html>`) """ def draw_density(p, low, high, k1, k2, ncols=2): m = low #lower bound of violin M = high #upper bound of violin x = np.linspace(m, M, 100) # support for violin v1 = k1.evaluate(x) # violin profile (density curve) v1 = w*v1/v1.max() # scaling the violin to the available space v2 = k2.evaluate(x) # violin profile (density curve) v2 = w*v2/v2.max() # scaling the violin to the available space if ncols == 2: ax.fill_betweenx(x, -v1 + p, p, facecolor='black', edgecolor='black') ax.fill_betweenx(x, p, p + v2, facecolor='grey', edgecolor='gray') else: ax.fill_betweenx(x, -v1 + p, p + v2, facecolor='black', edgecolor='black') if pos is None: pos = [0,1] dist = np.max(pos)-np.min(pos) w = min(0.15*max(dist,1.0),0.5) * .5 #for major_xs in range(data.shape[1]): for num, rlabel in enumerate(rlabels): p = pos[num] d1 = data.ix[rlabel, 0] k1 = scipy.stats.gaussian_kde(d1) #calculates the kernel density if data.shape[1] == 1: d2 = d1 k2 = k1 else: d2 = data.ix[rlabel, 1] k2 = scipy.stats.gaussian_kde(d2) #calculates the kernel density cutoff = .001 if cut is None: upper = max(d1.max(),d2.max()) lower = min(d1.min(),d2.min()) stepsize = (upper - lower) / 100 area_low1 = 1 # max cdf value area_low2 = 1 # max cdf value low = min(d1.min(), d2.min()) while area_low1 > cutoff or area_low2 > cutoff: area_low1 = k1.integrate_box_1d(-np.inf, low) area_low2 = k2.integrate_box_1d(-np.inf, low) low -= stepsize #print area_low, low, '.' 
area_high1 = 1 # max cdf value area_high2 = 1 # max cdf value high = max(d1.max(), d2.max()) while area_high1 > cutoff or area_high2 > cutoff: area_high1 = k1.integrate_box_1d(high, np.inf) area_high2 = k2.integrate_box_1d(high, np.inf) high += stepsize else: low, high = cut draw_density(p, low, high, k1, k2, ncols=data.shape[1]) # a work-around for generating a legend for the PolyCollection # from http://matplotlib.org/users/legend_guide.html#using-proxy-artist left = Rectangle((0, 0), 1, 1, fc="black", ec='black') right = Rectangle((0, 0), 1, 1, fc="gray", ec='gray') ax.legend((left, right), data.columns.tolist()) #import pdb; pdb.set_trace() #ax.set_xlim(pos[0]-3*w, pos[-1]+3*w) #if bp: #ax.boxplot(data,notch=1,positions=pos,vert=1) return ax def _stripchart(self, data, pos, rlabels, ax=None, mean=False, median=False, width=None, discrete=True, bins=30): """Plot samples given in `data` as horizontal lines. :Kwargs: mean: plot mean of each dataset as a thicker line if True median: plot median of each dataset as a dot if True. width: Horizontal width of a single dataset plot. """ def draw_lines(d, maxcount, hist, bin_edges, sides=None): if discrete: bin_edges = bin_edges[:-1] # upper edges not needed hw = hist * w / (2.*maxcount) else: bin_edges = d hw = w / 2. 
ax.hlines(bin_edges, sides[0]*hw + p, sides[1]*hw + p, color='white') if mean: # draws a longer black line ax.hlines(np.mean(d), sides[0]*2*w + p, sides[1]*2*w + p, lw=2, color='black') if median: # puts a white dot ax.plot(p, np.median(d), 'o', color='white', markeredgewidth=0) #data, pos = self._beanlike_setup(data, ax) if width: w = width else: #if pos is None: #pos = [0,1] dist = np.max(pos)-np.min(pos) w = min(0.15*max(dist,1.0),0.5) * .5 #colnames = [d for d in data.columns.names if d.startswith('cols.') ] #if len(colnames) == 0: # nothing specified explicitly as a columns #try: #colnames = data.columns.levels[-1] #except: #colnames = data.columns #func1d = lambda x: np.histogram(x, bins=bins) # apply along cols hist, bin_edges = np.apply_along_axis(np.histogram, 0, data, bins) # it return arrays of object type, so we got to correct that hist = np.array(hist.tolist()) bin_edges = np.array(bin_edges.tolist()) maxcount = np.max(hist) for n, rlabel in enumerate(rlabels): p = pos[n] d = data.ix[:,rlabel] if len(d.columns) == 1: draw_lines(d.ix[:,0], maxcount, hist[0], bin_edges[0], sides=[-1,1]) else: draw_lines(d.ix[:,0], maxcount, hist[0], bin_edges[0], sides=[-1,0]) draw_lines(d.ix[:,1], maxcount, hist[1], bin_edges[1], sides=[ 0,1]) ax.set_xlim(min(pos)-3*w, max(pos)+3*w) #ax.set_xticks([-1]+pos+[1]) ax.set_xticks(pos) #import pdb; pdb.set_trace() #ax.set_xticklabels(['-1']+np.array(data.major_axis).tolist()+['1']) if len(rlabels) > 1: ax.set_xticklabels(rlabels) else: ax.set_xticklabels('') return ax def beanplot(self, data, ax=None, pos=None, mean=True, median=True, cut=None, order=None, discrete=True, **kwargs): """Make a bean plot of each dataset in the `data` sequence. 
Reference: http://www.jstatsoft.org/v28/c01/paper """ data_tr, pos, rlabels = self._beanlike_setup(data, ax, order) dist = np.max(pos) - np.min(pos) w = min(0.15*max(dist,1.0),0.5) * .5 ax = self._stripchart(data, pos, rlabels, ax=ax, mean=mean, median=median, width=0.8*w, discrete=discrete) ax = self._violinplot(data_tr, pos, rlabels, ax=ax, bp=False, cut=cut) return ax def _beanlike_setup(self, data, ax, order=None): data = pandas.DataFrame(data) # Series will be forced into a DataFrame data = data.unstack([n for n in data.index.names if n.startswith('yerr.')]) data = data.unstack([n for n in data.index.names if n.startswith('rows.')]) rlabels = data.columns data = data.unstack([n for n in data.index.names if n.startswith('yerr.')]) data = data.T # now rows and values are in rows, cols in cols #if len(data.columns) > 2: #raise Exception('Beanplot cannot handle more than two categories') if len(data.index.levels[-1]) <= 1: raise Exception('Cannot make a beanplot for a single observation') ## put columns at the bottom so that it's easy to iterate in violinplot #order = {'rows': [], 'cols': []} #for i,n in enumerate(data.columns.names): #if n.startswith('cols.'): #order['cols'].append(i) #else: #order['rows'].append(i) #data = data.reorder_levels(order['rows'] + order['cols'], axis=1) if ax is None: ax = self.next() #if order is None: pos = range(len(rlabels)) #else: #pos = np.lexsort((np.array(data.index).tolist(), order)) return data, pos, rlabels