def tempo_diagram(vehicledata, ratio, density, lanes, road_length=100):
    """Plot a space-time occupancy diagram per lane and save it as a PNG.

    Every vehicle of every timestep marks the cells it occupies in a boolean
    ``(timesteps, lanes, road_length)`` grid.  Each lane is then drawn as an
    image with time on the vertical axis and position on the horizontal axis.

    Parameters
    ----------
    vehicledata : array-like with a ``shape`` attribute
        ``vehicledata[t]`` is an iterable of vehicles for timestep ``t``.  Each
        vehicle is indexed as ``vehicle[0]`` (position), ``vehicle[1]`` (lane)
        and ``vehicle[3]`` (size).  A vehicle of size 4 occupies lanes
        ``lane`` and ``lane + 1`` and cells ``pos - 1`` and ``pos``.  Any other
        size occupies the single cell ``(lane, pos)``.
    ratio, density : float
        Only used to build the output file name ``CR.<ratio>.D<density>.png``.
    lanes : int
        Number of lanes; one panel is drawn per lane.
    road_length : int, optional
        Number of cells along the road (default 100, the previously
        hard-coded value).
    """
    timesteps = vehicledata.shape[0]
    road = np.zeros((timesteps, lanes, road_length), dtype=bool)
    for t, vehicles in enumerate(vehicledata):
        for vehicle in vehicles:
            # Cast so that float-typed simulation arrays can still be used as
            # indices.
            pos = int(vehicle[0])
            lane = int(vehicle[1])
            size = vehicle[3]
            if size == 4:
                # NOTE(review): for pos == 0 the slice -1:1 is empty, so such a
                # vehicle is not drawn; confirm whether the road is periodic.
                road[t, lane:lane + 2, pos - 1:pos + 1] = True
            else:
                road[t, lane, pos] = True

    fig = plt.figure(1)
    # Figure 1 is reused on every call; clear it so panels from a previous
    # call are not drawn underneath the new ones.
    fig.clf()
    if lanes > 1:
        grid = ImageGrid(fig, 111, nrows_ncols=(1, lanes), axes_pad=0.1)
        for i in range(lanes):
            grid[i].imshow(road[:, i, :], cmap="binary",
                           interpolation="nearest")
            grid[i].set_title(r"$L_%s$" % i)
            grid[i].set_xticklabels([])
    else:
        ax = fig.add_subplot(111)
        ax.imshow(road[:, 0, :], cmap="binary", interpolation="nearest")
        # The only lane is lane 0.  The original referenced the undefined
        # loop variable ``i`` here, which raised NameError for single-lane
        # roads.
        ax.set_title(r"$L_%s$" % 0)
        ax.set_xticklabels([])
    plt.savefig('CR.%.2f.D%.2f.png' % (ratio, density), bbox_inches="tight")
Beispiel #2
0
    def _create_subplots(self, kind='', figsize=None, nrows=1, ncols=1, rect=111,
        cbar_mode='single', squeeze=False, **kwargs):
        """Create (or recreate) the figure and its grid of axes.

        If ``self.fig`` already exists, it is cleared and its figure number is
        reused, so repeated calls draw into the same figure.

        :Kwargs:
            - kind (str, default: '')
                The kind of plot. For plotting matrices or images
                (`matplotlib.pyplot.imshow`), choose `matrix`, otherwise leave
                blank.
            - figsize (tuple, default: None)
                Size of the figure.
            - nrows, ncols (int, default: 1)
                Shape of the subplot arrangement. Ignored when
                ``nrows_ncols`` is given.
            - nrows_ncols (tuple, optional)
                Shape of the subplot arrangement; overrides ``nrows`` and
                ``ncols``.
            - rect (default: 111)
                Grid position passed to `mpl_toolkits.axes_grid1.ImageGrid`.
                Only used when ``kind == 'matrix'``.
            - cbar_mode (str, default: 'single')
                Colorbar mode passed to `ImageGrid`. Only used when
                ``kind == 'matrix'``.
            - squeeze (bool, default: False)
                Passed to `matplotlib.pyplot.subplots` for all other kinds.
            - **kwargs
                Remaining keyword arguments, forwarded to `ImageGrid` (matrix)
                or to `matplotlib.pyplot.subplots` (otherwise).
        :Returns:
            `(fig, axes)`: the figure and its axes. For matrices, axes is the
            `ImageGrid` itself; otherwise it is a flat (raveled) array of
            axes.
        """
        # ``kwargs`` is a fresh dict for this call, so popping from it is safe.
        nrows_ncols = kwargs.pop('nrows_ncols', (nrows, ncols))
        try:
            num = self.fig.number
            self.fig.clf()
        except AttributeError:
            # No usable figure yet (``self.fig`` unset or None): let
            # matplotlib allocate a new figure number.
            num = None
        if kind == 'matrix':
            # NOTE(review): ``self.figure`` is defined elsewhere on this class
            # (not visible here); presumably a wrapper around plt.figure.
            self.fig = self.figure(figsize=figsize, num=num)
            self.axes = ImageGrid(self.fig, rect,
                                  nrows_ncols=nrows_ncols,
                                  cbar_mode=cbar_mode,
                                  **kwargs
                                  )
        else:
            self.fig, self.axes = plt.subplots(
                nrows=nrows_ncols[0],
                ncols=nrows_ncols[1],
                figsize=figsize,
                squeeze=squeeze,
                num=num,
                **kwargs
                )
            self.axes = self.axes.ravel()  # turn the axes array into a flat sequence
        self.kind = kind
        self.subplotno = -1  # will get +1 after the plot command
        self.nrows_ncols = nrows_ncols
        return (self.fig, self.axes)
Beispiel #3
0
def plot_color_index(g_i_list=None,
                     i_list=None,
                     list_names=None,
                     limits=None,
                     savedir='./',
                     filename=None,
                     show=True,
                     title='',
                     redshifts=None):
    ''' Plot M_g - M_i colour against M_i for several tracks.

    Parameters
    ----------
    g_i_list : sequence of sequences
        M_g - M_i values (y axis), one sequence per track. Indexed in
        parallel with ``i_list``.
    i_list : sequence of sequences
        M_i values (x axis), one sequence per track.
    list_names : sequence
        One label per track, shown in the legend as ``'<name> Gyr'``.
    limits : sequence of 4 floats, optional
        ``(xmin, xmax, ymin, ymax)`` axis limits.
    savedir, filename : str
        When ``filename`` is given, the figure is saved to
        ``savedir + filename`` at 600 dpi.
    show : bool
        Call ``fig.show()`` at the end.
    title : str
        Axes title.
    redshifts : sequence, optional
        Labels ``'z=<value>'`` are annotated next to the points of the first
        two tracks. Skipped when None.

    Raises
    ------
    ValueError
        If ``i_list`` or ``g_i_list`` is not given.

    Note: this resets and then updates the global matplotlib rcParams.
    '''
    # Validate before importing plotting modules so bad calls fail fast.
    if i_list is None or g_i_list is None:
        raise ValueError('i_list and g_i_list are required')

    # Import external modules.  The former unused ``pyfits`` import was
    # dropped; it is deprecated and made the function fail where it is not
    # installed.
    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    colormap = plt.cm.gist_ncar
    color_cycle = [colormap(v) for v in np.linspace(0, 0.9, len(i_list))]
    font_scale = 12
    params = {
        'axes.labelsize': font_scale,
        'axes.titlesize': font_scale,
        # 'text.fontsize' and 'axes.color_cycle' were removed from
        # matplotlib (rcParams.update rejects them).  'font.size' and
        # 'axes.prop_cycle' are their replacements.
        'font.size': font_scale,
        'legend.fontsize': font_scale * 3 / 4,
        'xtick.labelsize': font_scale,
        'ytick.labelsize': font_scale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (6, 6),
        'axes.prop_cycle': matplotlib.cycler(color=color_cycle),
    }
    plt.rcParams.update(params)

    # Create figure
    fig = plt.figure()
    grid = ImageGrid(fig, (1, 1, 1),
                     nrows_ncols=(1, 1),
                     ngrids=1,
                     direction='row',
                     axes_pad=1,
                     aspect=False,
                     share_all=True,
                     label_mode='all')  # valid modes are lower-case 'all'

    ax = grid[0]
    for k, x_vals in enumerate(i_list):
        y_vals = g_i_list[k]
        ax.plot(x_vals,
                y_vals,
                label='%s Gyr' % list_names[k],
                marker='s')
        # Only the first two tracks get redshift annotations.
        if k < 2 and redshifts is not None:
            for j, z in enumerate(redshifts):
                ax.annotate('z=%s' % z,
                            xy=(x_vals[j], y_vals[j]),
                            textcoords='offset points',
                            xytext=(2, 3),
                            size=font_scale * 0.75)

    if limits is not None:
        ax.set_xlim(limits[0], limits[1])
        ax.set_ylim(limits[2], limits[3])

    # Adjust aesthetics
    ax.set_xlabel(r'$M_i$ (mag)')
    ax.set_ylabel(r'$M_g - M_i$ (mag)')
    ax.grid(True)
    ax.legend(loc='lower right')  # 'bottom right' is not a valid location
    ax.set_title(title)

    if filename is not None:
        plt.savefig(savedir + filename, bbox_inches='tight', dpi=600)
    if show:
        fig.show()
Beispiel #4
0
def plot_segmask_3cl_input(y_pred,
                           x_in,
                           y_true=None,
                           class_to_plot=(2, 3),
                           input_to_plot=1,
                           input_name='channel 1',
                           xtick_int=50,
                           ytick_int=50,
                           show_plt=True,
                           save_imag=True,
                           imag_name='pred_mask_input',
                           save_as='pdf'):
    '''Plot a 3-class segmentation mask next to one input channel.

    Prototype version; it could be combined with ``plot_segmask_input``.

    Parameters
    ----------
    y_pred : ndarray
        Per-pixel class scores with classes on the last axis. ``argmax`` over
        that axis, plus 1, gives 1-based class ids.
    x_in : ndarray
        Input image with channels on the last axis.
    y_true : ndarray, optional
        Ground truth in the same layout as ``y_pred``. When given, a
        "Ground truth" panel is added between the mask and the input.
    class_to_plot : tuple of 2 ints
        1-based ids of the two classes to draw. The first is drawn at value
        0.5, the second at 1.0, and every other pixel at 0.
    input_to_plot : int
        1-based index of the input channel to show.
    input_name : str
        Title of the input panel.
    xtick_int, ytick_int : int
        Tick spacing in pixels.
    show_plt : bool
        Display the figure with ``plt.show()``; otherwise it is closed.
    save_imag : bool
        Save the figure to ``imag_name + '.' + save_as``.
    save_as : str
        File extension / format of the saved figure.
    '''

    def _to_mask(onehot):
        # Convert class scores to a float mask: 0.5 where the argmax class is
        # class_to_plot[0], 1.0 where it is class_to_plot[1], 0 elsewhere.
        labels = np.argmax(onehot, axis=-1) + 1
        return ((labels * (labels == class_to_plot[0])) *
                (0.5 / class_to_plot[0])) + (
                    (labels * (labels == class_to_plot[1])) *
                    (1.0 / class_to_plot[1]))

    pred_mask = _to_mask(y_pred)
    input_imag = x_in[..., input_to_plot - 1]
    if y_true is not None:
        # Plot prediction mask, ground truth and input image (1 channel)
        grid_cmap = ['jet', 'gray', 'gray']
        grid_imag = [pred_mask, _to_mask(y_true), input_imag]
        grid_title = [r'Segmentation mask', r'Ground truth', input_name]
        fig_width = 10.5
    else:
        # Plot prediction mask and input image (1 channel)
        grid_cmap = ['jet', 'gray']
        grid_imag = [pred_mask, input_imag]
        grid_title = [r'Segmentation mask', input_name]
        fig_width = 6.8
    n_cols = len(grid_imag)

    fig = plt.figure(figsize=(fig_width, 4))
    grid = ImageGrid(fig,
                     rect=[0.1, 0.07, 0.85, 0.9],
                     nrows_ncols=(1, n_cols),
                     axes_pad=0.25,
                     share_all=True)
    for ax, imag, cmap, panel_title in zip(grid, grid_imag, grid_cmap,
                                           grid_title):
        ax.imshow(imag, vmin=0, vmax=1, cmap=cmap)
        ax.set_xticks(np.arange(0, pred_mask.shape[1] + 1, xtick_int))
        ax.set_yticks(np.arange(0, pred_mask.shape[0] + 1, ytick_int))
        ax.set_xlabel(r'image width [pixel]')
        ax.set_ylabel(r'image height [pixel]')
        ax.set_title(panel_title)

    if save_imag:
        # Save BEFORE showing: once plt.show() returns, the figure may already
        # be closed, and savefig would then write a blank image.
        plt.savefig(imag_name + '.' + save_as, bbox_inches='tight')
    if show_plt:
        plt.show()
    else:
        # Free the figure (or matplotlib history) when it is not shown.  This
        # now also happens when it is not saved, to avoid leaking figures.
        plt.close(fig)
Beispiel #5
0
def diagnostic_plots(model,
                     chainfile,
                     burnin=0.3,
                     channelmaps=True,
                     cornerplot=True,
                     momentmaps=True,
                     bestarray=False,
                     outname='model',
                     vrangecm=None,
                     vrangemom=None,
                     cmap0='RdYlBu_r',
                     cmap1='Spectral_r',
                     cmap2='copper_r',
                     maskvalue=3.0,
                     **kwargs):
    """ Create diagnostic plots for a model and its MCMC chain.

    The plots are not paper-ready; they are meant to help judge how well the
    MCMC chain ran. Depending on the flags, the following PDF files are
    written, all prefixed with ``outname``:

    - ``_cornerplot.pdf``: corner plot of the post-burn-in chain. Each
      parameter is multiplied by its conversion factor (uses the corner
      package).
    - ``_datachannelmap.pdf``, ``_modelchannelmap.pdf``,
      ``_residualchannelmap.pdf`` and ``_combinedchannelmap.pdf``: channel
      maps of the data, the model, the residual, and the data with residual
      contours.
    - ``_moment0.pdf``: moment-0 maps of the data, the model, the residual,
      and the data with residual contours.
    - ``_moment1.pdf`` and ``_moment2.pdf``: moment-1 and moment-2 maps of
      the data and the model. These are computed on cubes masked with the
      mask derived from the data moment-0 map.

    input and keywords:
    model:       This is a Qube object that contains a defined model.
    chainfile:   This is a file containing an MCMC chain saved with
                 numpy.save. The shape of the object is (nwalkers, nruns,
                 dim+1), where dim is the number of variable parameters in
                 the MCMC chain. The extra slice contains the lnprobability
                 of the given parameters in that link.
    burnin:      Burn-in fraction of the chain (default = 30%, i.e., 0.3).
    channelmaps: If True, plot the channel maps of the data, model and
                 residual.
    cornerplot:  If True, plot the corner plot of the chain file.
    momentmaps:  If True, plot the zeroth, first and second moment of the
                 data and the model.
    bestarray:   If True, the plots are generated from the model with the
                 highest probability; if False, the median parameters are
                 used.
    outname:     The root name used for the saved pdf figures.
    vrangecm:    z-axis range used to plot the channel maps (default: the
                 nanmin/nanmax of the data cube).
    vrangemom:   z-axis range used to plot the moment-0 maps (default:
                 [-3, 11] times the moment-0 noise).
    cmap0:       Colormap used for the moment-0 maps.
    cmap1:       Colormap used for the moment-1 maps.
    cmap2:       Colormap used for the moment-2 maps.
    maskvalue:   Value (in sigma) used to generate the mask from the
                 moment-0 image; the mask is applied to the higher-moment
                 images. Default is 3.0.
    **kwargs:    Forwarded unchanged to every create_channelmap and
                 standardfig call.

    Side effects: the model's parameters are updated (update_parameters)
    and the model is regenerated (create_model). plt.close() is called after
    each saved figure.
    """

    # read in the chain data file and load it in the model then regenerate
    # the model with the wanted values (either median or best)
    Chain = np.load(chainfile)
    # drop the burn-in links (axis 1) and the trailing lnprobability slice
    Chain = Chain[:, int(burnin * Chain.shape[1]):, :-1]
    # flatten walkers x links into one sample axis: (nsamples, dim)
    Chain = Chain.reshape((-1, Chain.shape[2]))
    # presumably fills model.chainpar (read just below) -- TODO confirm
    model.get_chainresults(chainfile, burnin=burnin)
    if not bestarray:
        model.update_parameters(model.chainpar['MedianArray'])
    else:
        model.update_parameters(model.chainpar['BestArray'])
    model.create_model()

    # create the data, model and residual cubes
    # NOTE(review): dc is imported elsewhere; presumably copy.deepcopy.
    dqube = dc(model)
    mqube = dc(model)
    mqube.data = model.model
    rqube = dc(model)
    rqube.data = model.data - model.model

    # make the corner plot
    if cornerplot is True:
        # convert each parameter to a physically meaningful quantity
        # units = []
        for idx, key in enumerate(model.mcmcmap):

            # get conversion factors and units for each key
            if model.initpar[key]['Conversion'] is not None:
                conversion = model.initpar[key]['Conversion'].value
            else:
                conversion = 1.
            # units.append(model.initpar[key]['Unit'].to_string())
            # quick fix for IO: I0 gets an extra factor 1E3 on top of its
            # conversion factor
            if key == 'I0':
                Chain[:, idx] = Chain[:, idx] * 1E3

            Chain[:, idx] = Chain[:, idx] * conversion
        corner.corner(Chain,
                      labels=model.mcmcmap,
                      quantiles=[0.16, 0.5, 0.84],
                      show_titles=True)

        plt.savefig(outname + '_cornerplot.pdf', format='pdf', dpi=300)
        plt.close()

    # make the channel maps
    if channelmaps is True:
        # define some global properties for all plots
        if vrangecm is None:
            vrangecm = [np.nanmin(dqube.data), np.nanmax(dqube.data)]
        # one noise value per element of model.variance[:, 0, 0]
        # (presumably one per channel -- TODO confirm)
        sigma = np.sqrt(model.variance[:, 0, 0])
        # contour levels -3, 3, 6, ..., 27 times the noise of each channel
        clevels = [np.insert(np.arange(3, 30, 3), 0, -3) * i for i in sigma]

        # make the channel map for the data
        create_channelmap(raster=dqube,
                          contour=dqube,
                          clevels=clevels,
                          pdfname=outname + '_datachannelmap.pdf',
                          vrange=vrangecm,
                          **kwargs)

        # make the channel map for the model
        create_channelmap(raster=mqube,
                          contour=mqube,
                          clevels=clevels,
                          pdfname=outname + '_modelchannelmap.pdf',
                          vrange=vrangecm,
                          **kwargs)

        # make the channel map for the residual
        create_channelmap(raster=rqube,
                          contour=rqube,
                          clevels=clevels,
                          pdfname=outname + '_residualchannelmap.pdf',
                          vrange=vrangecm,
                          **kwargs)

        # make the channel map for the data with residual contours
        create_channelmap(raster=dqube,
                          contour=rqube,
                          clevels=clevels,
                          pdfname=outname + '_combinedchannelmap.pdf',
                          vrange=vrangecm,
                          **kwargs)

        plt.close()

    # make the moment maps
    if momentmaps is True:
        # create the moment-0 images
        dMom0 = dqube.calculate_moment(moment=0)
        mMom0 = mqube.calculate_moment(moment=0)
        rMom0 = rqube.calculate_moment(moment=0)

        # calculate the Mom0sig of the data and create the contour levels:
        # sqrt of the summed variance times the velocity width
        Mom0sig = (np.sqrt(np.nansum(model.variance[:, 0, 0])) *
                   model.__get_velocitywidth__())
        clevels = np.insert(np.arange(3, 30, 3), 0, -3) * Mom0sig
        if vrangemom is None:
            vrangemom = [-3 * Mom0sig, 11 * Mom0sig]
        # mask from the data moment-0 map at maskvalue * Mom0sig; it is
        # reused below for the moment-1 and moment-2 maps
        mask = dMom0.mask_region(value=Mom0sig * maskvalue, applymask=False)

        # create the figure (figure number 1 is reused for every moment map)
        fig = plt.figure(1, (8., 8.))
        grid = ImageGrid(fig,
                         111,
                         nrows_ncols=(2, 2),
                         axes_pad=0.,
                         cbar_mode='single',
                         cbar_location='right')

        # plot the figures
        standardfig(raster=dMom0,
                    contour=dMom0,
                    clevels=clevels,
                    ax=grid[0],
                    fig=fig,
                    vrange=vrangemom,
                    cmap=cmap0,
                    text='Data',
                    textprop=[dict(size=12)],
                    **kwargs)
        standardfig(raster=mMom0,
                    contour=mMom0,
                    clevels=clevels,
                    ax=grid[1],
                    fig=fig,
                    vrange=vrangemom,
                    cmap=cmap0,
                    beam=False,
                    text='Model',
                    textprop=[dict(size=12)],
                    **kwargs)
        standardfig(raster=rMom0,
                    contour=rMom0,
                    clevels=clevels,
                    ax=grid[2],
                    fig=fig,
                    vrange=vrangemom,
                    cmap=cmap0,
                    text='Residual',
                    textprop=[dict(size=12)],
                    **kwargs)
        standardfig(raster=dMom0,
                    contour=rMom0,
                    clevels=clevels,
                    ax=grid[3],
                    fig=fig,
                    vrange=vrangemom,
                    cmap=cmap0,
                    text='Data with residual contours',
                    textprop=[dict(size=12)],
                    **kwargs)

        # plot the color bar: a standalone ScalarMappable with the same
        # colormap and range as the panels, drawn into the shared cbar axis
        norm = mpl.colors.Normalize(vmin=vrangemom[0], vmax=vrangemom[1])
        cmapo = plt.cm.ScalarMappable(cmap=cmap0, norm=norm)
        cmapo.set_array([])
        cbr = plt.colorbar(cmapo, cax=grid.cbar_axes[0])
        cbr.ax.set_ylabel('Moment-0', labelpad=-1)
        plt.savefig(outname + '_moment0.pdf', format='pdf', dpi=300)
        plt.close()

        # create the moment-1 images from the masked data and model cubes
        dsqube = dqube.mask_region(mask=mask)
        dMom1 = dsqube.calculate_moment(moment=1)
        msqube = mqube.mask_region(mask=mask)
        mMom1 = msqube.calculate_moment(moment=1)

        # create the figure
        fig = plt.figure(1, (8., 5.))
        grid = ImageGrid(fig,
                         111,
                         nrows_ncols=(1, 2),
                         axes_pad=0.,
                         cbar_mode='single',
                         cbar_location='right')

        # plot the figures; the color range comes from the data map only
        vrangemom1 = [np.nanmin(dMom1.data), np.nanmax(dMom1.data)]
        standardfig(raster=dMom1,
                    ax=grid[0],
                    fig=fig,
                    cmap=cmap1,
                    vrange=vrangemom1,
                    text='Data',
                    textprop=[dict(size=12)],
                    **kwargs)
        standardfig(raster=mMom1,
                    ax=grid[1],
                    fig=fig,
                    cmap=cmap1,
                    vrange=vrangemom1,
                    beam=False,
                    text='Model',
                    textprop=[dict(size=12)],
                    **kwargs)

        # plot the color bar
        norm = mpl.colors.Normalize(vmin=vrangemom1[0], vmax=vrangemom1[1])
        cmapo = plt.cm.ScalarMappable(cmap=cmap1, norm=norm)
        cmapo.set_array([])
        cbr = plt.colorbar(cmapo, cax=grid.cbar_axes[0])
        cbr.ax.set_ylabel('Moment-1', labelpad=-1)

        plt.savefig(outname + '_moment1.pdf', format='pdf', dpi=300)
        plt.close()

        # create the moment-2 images
        # NOTE(review): dsqube/msqube are recomputed with the same mask as
        # for moment-1; the earlier masked cubes could likely be reused.
        dsqube = dqube.mask_region(mask=mask)
        dMom2 = dsqube.calculate_moment(moment=2)
        msqube = mqube.mask_region(mask=mask)
        mMom2 = msqube.calculate_moment(moment=2)

        # create the figure
        # NOTE(review): figsize (5., 8.) is tall while the grid is 1 x 2
        # (side by side); moment-1 uses (8., 5.) -- confirm this is intended.
        fig = plt.figure(1, (5., 8.))
        grid = ImageGrid(fig,
                         111,
                         nrows_ncols=(1, 2),
                         axes_pad=0.,
                         cbar_mode='single',
                         cbar_location='right')

        # plot the figures; the color range comes from the data map only
        vrangemom2 = [np.nanmin(dMom2.data), np.nanmax(dMom2.data)]
        standardfig(raster=dMom2,
                    ax=grid[0],
                    fig=fig,
                    cmap=cmap2,
                    vrange=vrangemom2,
                    text='Data',
                    textprop=[dict(size=12)],
                    **kwargs)
        standardfig(raster=mMom2,
                    ax=grid[1],
                    fig=fig,
                    cmap=cmap2,
                    vrange=vrangemom2,
                    beam=False,
                    text='Model',
                    textprop=[dict(size=12)],
                    **kwargs)

        # plot the color bar
        norm = mpl.colors.Normalize(vmin=vrangemom2[0], vmax=vrangemom2[1])
        cmapo = plt.cm.ScalarMappable(cmap=cmap2, norm=norm)
        cmapo.set_array([])
        cbr = plt.colorbar(cmapo, cax=grid.cbar_axes[0])
        cbr.ax.set_ylabel('Moment-2', labelpad=-1)

        plt.savefig(outname + '_moment2.pdf', format='pdf', dpi=300)
        plt.close()
Beispiel #6
0
    def _savegif(
        self,
        stems: List[str],
        imgs: NDArray,
        masks: NDArray,
        reconstructed_imgs: NDArray,
        amaps: NDArray,
    ) -> None:
        """Render one 4-panel PNG per sample and combine them into ``result.gif``.

        For each sample, a figure with the input image, the reconstructed
        image, the ground-truth mask overlay and the anomaly-map overlay (with
        a shared colorbar) is saved to ``results/<stem>.png``.  The frames
        written by this call are then combined in sorted path order with
        ImageMagick's ``convert`` into ``result.gif``.

        Args:
            stems: File stem for each sample's PNG.
            imgs: Input images.
            masks: Ground-truth masks drawn over the input.
            reconstructed_imgs: Reconstructed images.
            amaps: Anomaly maps drawn over the input, with the color range
                fixed to [0, 1].

        Raises:
            FileNotFoundError: if the ``convert`` executable is not found.
            subprocess.CalledProcessError: if ``convert`` fails.
        """
        import glob  # noqa: F401  (kept local; see frame collection below)

        # exist_ok: a previous run's directory must not abort this one.
        os.makedirs("results", exist_ok=True)

        def _hide_ticks(ax):
            # Images only: hide tick marks and tick labels on both axes.
            ax.tick_params(labelbottom=False,
                           labelleft=False,
                           bottom=False,
                           left=False)

        frames = []
        pbar = tqdm(zip(stems, imgs, masks, reconstructed_imgs, amaps),
                    desc="savegif")
        for stem, img, mask, reconstructed_img, amap in pbar:

            # How to get two subplots to share the same y-axis with a single colorbar
            # https://stackoverflow.com/a/38940369
            grid = ImageGrid(
                fig=plt.figure(figsize=(16, 4)),
                rect=111,
                nrows_ncols=(1, 4),
                axes_pad=0.15,
                share_all=True,
                cbar_location="right",
                cbar_mode="single",
                cbar_size="5%",
                cbar_pad=0.15,
            )

            grid[0].imshow(img, cmap="gray")
            grid[0].set_title("Input Image", fontsize=20)

            grid[1].imshow(reconstructed_img, cmap="gray")
            grid[1].set_title("Reconstructed Image", fontsize=20)

            grid[2].imshow(img, cmap="gray")
            grid[2].imshow(mask, alpha=0.3, cmap="Reds")
            grid[2].set_title("Ground Truth", fontsize=20)

            grid[3].imshow(img, cmap="gray")
            im = grid[3].imshow(amap, alpha=0.3, cmap="jet", vmin=0, vmax=1)
            grid[3].cax.toggle_label(True)
            grid[3].set_title("Anomaly Map", fontsize=20)

            for ax in grid:
                _hide_ticks(ax)

            plt.colorbar(im, cax=grid.cbar_axes[0])
            frame = f"results/{stem}.png"
            plt.savefig(frame, bbox_inches="tight")
            plt.close()
            frames.append(frame)

        # Only use frames written by this call (results/ may now hold stale
        # PNGs from earlier runs), de-duplicated and sorted like the shell
        # glob the original relied on.
        frames = sorted(set(frames))
        if not frames:
            return

        # NOTE(inoue): The gif files converted by PIL or imageio were low-quality.
        #              So, I used the conversion command (ImageMagick) instead.
        # Argument list instead of a shell string: no shell parsing of stems.
        subprocess.run(
            ["convert", "-delay", "100", "-loop", "0", *frames, "result.gif"],
            check=True)
with torch.no_grad():
    # `samp` comes from an earlier cell (not shown here).
    reco = hm.reconstruct(samp)

    reco_im = torch.squeeze(reco).reshape(28, 28)
    samp_im = torch.squeeze(samp).reshape(28, 28)

# Show the input sample, then its reconstruction.
plt.imshow(samp_im)
plt.show()
plt.imshow(reco_im)
plt.show()

# <codecell>
# Draw n_side * n_side samples and show them on an n_side x n_side grid.
n_side = 5
n_samples = n_side * n_side
with torch.no_grad():
    samp = hm.sample(n_samples)
    samp_im = torch.squeeze(samp).reshape(n_samples, 28, 28)

fig = plt.figure(figsize=(10, 10))
grid = ImageGrid(
    fig,
    111,  # similar to subplot(111)
    nrows_ncols=(n_side, n_side),  # creates an n_side x n_side grid of axes
    axes_pad=0.1,  # pad between axes in inch.
)

for ax, im in zip(grid, samp_im):
    ax.imshow(im)

# NOTE(review): 28x28 samples look like digit-sized images; confirm the
# "faces" wording of the title.
fig.suptitle('Sample faces drawn from HM')
plt.show()

# TODO: debug same image problem
Beispiel #8
0
def make_gallery(files=None,
                 ims=None,
                 heads=None,
                 outname=None,
                 pfovs=None,
                 scale="lin",
                 norm=None,
                 absmin=None,
                 absmax=None,
                 permax=None,
                 permin=None,
                 papercol=2,
                 ncols=4,
                 pwidth=None,
                 nrows=None,
                 view_as=None,
                 view_px=None,
                 inv=True,
                 cmap='gist_heat',
                 colorbar=None,
                 cbarwidth=0.03,
                 cbarlabel=None,
                 cbarlabelpad=None,
                 xlabel=None,
                 ylabel=None,
                 xlabelpad=None,
                 ylabelpad=None,
                 titles=None,
                 titcols='black',
                 titx=0.5,
                 tity=0.95,
                 titvalign='top',
                 tithalign='center',
                 subtitles=None,
                 subtitcols='black',
                 subtitx=0.5,
                 subtity=0.05,
                 subtitvalign='bottom',
                 subtithalign='center',
                 plotsym=None,
                 contours=None,
                 alphamap=None,
                 cbartickinterval=None,
                 sbarlength=None,
                 sbarunit='as',
                 sbarcol='black',
                 sbarthick='2',
                 sbarpos=[0.1, 0.05],
                 majtickinterval=None,
                 verbose=False,
                 latex=True,
                 texts=None,
                 textcols='black',
                 textx=0.05,
                 texty=0.95,
                 textvalign='top',
                 texthalign='left',
                 textsize=None,
                 replace_NaNs=None,
                 smooth=None,
                 lines=None,
                 latexfontsize=16,
                 axes_pad=0.05):
    """
    MISSING:
        - implementation to rotate images to North up (and East to the left
          if desired and necessary

    The purpose of this routine is to create puplication-quality multiplots of
    images with a maximum number of customisability. The following parameters
    can be set either for all images a single variable or as an array with the
    same number of elements to set the parameters individually
    INPUT:
        - flles : list of fits files to be plotted
        - pfovs : pixel size in arcsec for the images. If not provided then it
                  will be looked for in the fits headers
        - log : enable logarithmic scaling
        - norm: provide the index of the image that should be normalised to
        - permax: set a percentile value which should be used as the maximum
                  in the colormap
        - papercol: either 1 for a plot fitting into one column in the paper or
                    2 in case the plot is supposed to go over the full page
        - ncols: number of columns of images in the multiplot
        - nrows: number of rows of images in the multiplot
        - view_as: size of the field of view that should be plotted in arcsec
        - inv: if set true then the used colormap is inverted
        - cmap: colormap to be used
        - xlabelpad, xlabelpad : adjust the position of the x,y axis label
        - titles: provide titles for the individual images
        - titcols: colors for the titles
        - titx, tity: x,y position of the title
        - titvalign, tithalign: vertical and horizontal alignment of the title
                                text with respect to the given position
        - subtitles: similar to titles but for another text in the image


    """

    # --- read in the images in a 2D list
    if ims is None:

        if type(files) == list:
            n = len(files)

            if nrows is None and ncols is not None:
                nrows = int(np.ceil(1.0 * n / ncols))

            if ncols is None and nrows is not None:
                ncols = int(np.ceil(1.0 * n / nrows))

            # --- reshape file list into 2D array
            rest = n % nrows
            for i in range((nrows - rest) % nrows):
                files.append(None)

            files = np.reshape(files, (nrows, ncols))

        elif type(files) == np.ndarray:

            s = np.shape(files)
            if len(s) > 1:
                if nrows is None:
                    nrows = s[0]
                if ncols is None:
                    ncols = s[1]
                n = nrows * ncols

        else:
            n = 1
            nrows = 1
            ncols = 1

            files = np.reshape(files, (-1, ncols))

        # --- create a fill a 2D image array
#        ims = [[None]*ncols]*nrows
#        ims = [None]*nrows*ncols
        ims = [[None] * ncols for _ in range(nrows)]
        #print(np.shape(ims))

        for r in range(nrows):
            for c in range(ncols):
                if files[r][c] != None:
                    if files[r][c] != "":
                        i = ncols * r + c
                        ims[r][c] = fits.getdata(files[r][c], ext=0)

    # images are provides
    else:

        s = np.shape(ims)

        if type(ims) == list or (type(ims) == np.ndarray and len(s) < 4):
            n = len(ims)

            if nrows is None and ncols is not None:
                nrows = int(np.ceil(1.0 * n / ncols))

            if ncols is None and nrows is not None:
                ncols = int(np.ceil(1.0 * n / nrows))

            # --- reshape file list into 2D array
            rest = n % nrows
            for i in range((nrows - rest) % nrows):
                ims.append(None)

            ims_old = np.copy(ims)
            ims = [[None] * ncols for _ in range(nrows)]

            for r in range(nrows):
                for c in range(ncols):
                    i = ncols * r + c
                    ims[r][c] = ims_old[i]

        elif type(ims) == np.ndarray:

            if len(s) > 1:
                if nrows is None:
                    nrows = s[0]
                if ncols is None:
                    ncols = s[1]
                n = nrows * ncols

        else:
            n = 1
            nrows = 1
            ncols = 1

            ims = [[ims]]

    if heads is None and files is not None:
        heads = [[None] * ncols for _ in range(nrows)]
        for r in range(nrows):
            for c in range(ncols):
                if files[r][c] != None:
                    if files[r][c] != "":
                        heads[r][c] = fits.getheader(files[r][c], ext=0)
#    print(np.shape(heads), len(np.shape(heads)))
    if heads is not None:
        if len(np.shape(heads)) == 1:
            heads = np.full((nrows, ncols), heads, dtype=object)

    if pfovs is None and heads is not None:
        pfovs = [[None] * ncols for _ in range(nrows)]
        for r in range(nrows):
            for c in range(ncols):
                if heads[r][c] != None:
                    if heads[r][c] != "":
                        if "PFOV" in heads[r][c]:
                            pfovs[r][c] = float(heads[r][c]['PFOV'])
                        elif "HIERARCH ESO INS PFOV" in heads[r][c]:
                            pfovs[r][c] = float(
                                heads[r][c]["HIERARCH ESO INS PFOV"])
                        elif "CDELT1" in heads[r][c]:
                            pfovs[r][c] = np.abs(heads[r][c]["CDELT1"]) * 3600
        if len(pfovs) == 0:
            pfovs = None

    else:
        pfovs = reshape_input_param(pfovs, nrows, ncols)

    scale = reshape_input_param(scale, nrows, ncols)
    permin = reshape_input_param(permin, nrows, ncols)
    permax = reshape_input_param(permax, nrows, ncols)
    absmin = reshape_input_param(absmin, nrows, ncols)
    absmax = reshape_input_param(absmax, nrows, ncols)
    norm = reshape_input_param(norm, nrows, ncols)
    smooth = reshape_input_param(smooth, nrows, ncols)
    titles = reshape_input_param(titles, nrows, ncols)
    titcols = reshape_input_param(titcols, nrows, ncols)
    texts = reshape_input_param(texts, nrows, ncols)
    textcols = reshape_input_param(textcols, nrows, ncols)
    subtitles = reshape_input_param(subtitles, nrows, ncols)
    subtitcols = reshape_input_param(subtitcols, nrows, ncols)
    sbarlength = reshape_input_param(sbarlength, nrows, ncols)

    if lines is not None:
        lines = reshape_input_param(lines, nrows, ncols)

    if verbose:

        #        ny, nx = ims[0].shape
        print("MAKE_GALLERY: nrows, ncols: ", nrows, ncols)
        print("MAKE_GALLERY: n: ", n)
        print("MAKE_GALLERY: pfovs: ", pfovs)
        print("MAKE_GALLERY: scales: ", scale)
#       print("MAKE_GALLERY: nx, ny: ", nx, ny)

# --- set up the plotting configuration to use latex

    if latex:
        mpl.rcdefaults()

        mpl.rc(
            'font', **{
                'family': 'sans-serif',
                'serif': ['Computer Modern Serif'],
                'sans-serif': ['Helvetica'],
                'size': latexfontsize,
                'weight': 500,
                'variant': 'normal'
            })

        mpl.rc('axes', **{'labelweight': 'normal', 'linewidth': 1.5})
        mpl.rc('ytick', **{'major.pad': 8, 'color': 'k'})
        mpl.rc('xtick', **{'major.pad': 8})
        mpl.rc(
            'mathtext', **{
                'default': 'regular',
                'fontset': 'cm',
                'bf': 'monospace:bold'
            })

        mpl.rc('text', **{'usetex': True})
        mpl.rc('text.latex',preamble=r'\usepackage{cmbright},\usepackage{relsize},'+\
                                    r'\usepackage{upgreek}, \usepackage{amsmath}'+\
                                    r'\usepackage{bm}')

        mpl.rc('contour', **{'negative_linestyle': 'solid'})  # dashed | solid

    plt.clf()

    # 14.17
    # 6.93
    if pwidth == None:
        if papercol == 1:
            pwidth = 6.93
        elif papercol == 2:
            pwidth = 14.17
        else:
            pwidth = 7 * papercol

    subpwidth = pwidth / ncols

    if verbose:
        print("MAKE_GALLERY: pwidth, subpwidth: ", pwidth, subpwidth)

    fig = plt.figure(figsize=(pwidth, nrows * subpwidth))
    # fig.subplots_adjust(bottom=0.2)
    # fig.subplots_adjust(left=0.2)

    if inv:
        cmap = cmap + '_r'

    grid = ImageGrid(fig,
                     111,
                     nrows_ncols=(nrows, ncols),
                     axes_pad=axes_pad,
                     aspect=True)

    handles = []
    ahandles = []
    vmin = np.zeros((nrows, ncols))
    vmax = np.zeros((nrows, ncols))
    amin = np.zeros((nrows, ncols))
    amax = np.zeros((nrows, ncols))

    # --- main plotting loop
    for r in range(nrows):
        for c in range(ncols):

            i = ncols * r + c

            if verbose:
                print("MAKE_GALLERY: r,c,i", r, c, i)
                print(" - smooth: ", smooth[r][c])

            if np.shape(ims[r][c]) == ():
                continue

            if view_as is not None:
                view_px = view_as / pfovs[r][c]
                im = _crop_image(ims[r][c], box=np.ceil(view_px / 2.0) * 2 + 2)
            else:
                im = np.copy(ims[r][c])

            #print(r,c, np.nanmax(ims[r][c]), np.nanmax(im))

            if smooth[r][c]:
                im = gaussian_filter(im, sigma=smooth[r][c], mode='nearest')

            if norm[r][c]:
                im[0, 0] = np.nanmax(ims[int(norm[r][c])])

            # --- adjust the cut levels
            if permin[r][c]:
                vmin[r][c] = np.nanpercentile(im, float(permin[r][c]))
            else:
                vmin[r][c] = np.nanmin(im)

            if permax[r][c]:
                vmax[r][c] = np.nanpercentile(im, float(permax[r][c]))
            else:
                vmax[r][c] = np.nanmax(im)

            if absmax[r][c] is not None:
                vmax[r][c] = absmax[r][c]

            if absmin[r][c] is not None:
                vmin[r][c] = absmin[r][c]

            sim = np.copy(im)
            sim[im < vmin[r][c]] = vmin[r][c]
            sim[im > vmax[r][c]] = vmax[r][c]

            if replace_NaNs is not None:
                idn = np.nonzero(np.isnan(sim))
                #            print(i, len(idn), idn[0])
                if len(idn) == 0:
                    continue
                elif replace_NaNs == "min":
                    sim[np.ix_(idn[0], idn[1])] = vmin[r][c]
                elif replace_NaNs == "max":
                    sim[np.ix_(idn[0], idn[1])] = vmax[r][c]
                else:
                    sim[np.ix_(idn[0], idn[1])] = replace_NaNs

            # --- logarithmic scaling?
            if scale[r][c] == "log":
                sim = np.log10(1000.0 * (sim - vmin[r][c]) /
                               (vmax[r][c] - vmin[r][c]) + 1)

            if verbose:
                print("MAKE_GALLERY: scale[r][c]", scale[r][c])
                print("MAKE_GALLERY: vmin[r][c], vmax[r][c]", vmin[r][c],
                      vmax[r][c])

            handle = grid[i].imshow(sim,
                                    cmap=cmap,
                                    origin='lower',
                                    interpolation='nearest')

            ny, nx = im.shape

            if verbose:
                print("ny, nx:", ny, nx)

            if pfovs is not None:

                xmin = (nx / 2 - 1) * pfovs[r][c]
                xmax = -nx / 2 * pfovs[r][c]
                ymin = -ny / 2 * pfovs[r][c]
                ymax = (ny / 2 - 1) * pfovs[r][c]
                if xlabel is None:
                    xlabel = 'RA offset ["]'
                if ylabel is None:
                    ylabel = 'DEC offset ["]'

            else:

                xmax = (nx / 2 - 1)
                xmin = -nx / 2
                ymin = -ny / 2
                ymax = (ny / 2 - 1)
                if xlabel is None:
                    xlabel = 'x offset [px]'
                if ylabel is None:
                    ylabel = 'y offset [px]'

            # print(pfovs[r][c])
            # pdb.set_trace()
            # set the extent of the image
            extent = [xmin, xmax, ymin, ymax]

            if verbose:
                print("extent:", extent)
                print("x/ylabel:", xlabel, ylabel)

            handle.set_extent(extent)

            # --- optionally overplot a second map using transparency
            if alphamap is not None:

                uim = np.copy(alphamap[r][c]['im'])

                # --- determine the levels
                if alphamap[r][c]['min'] is None:
                    amin[r][c] = np.nanmin(im)
                elif str(alphamap[r][c]['min']) == 'vmax':
                    amin[r][c] = vmax[r][c]
                elif str(alphamap[r][c]['min']) == 'vmin':
                    amin[r][c] = vmin[r][c]
                else:
                    amin[r][c] = np.nanpercentile(im,
                                                  float(alphamap[r][c]['min']))

                if alphamap[r][c]['max'] is None:
                    amax[r][c] = np.nanmax(im)
                elif str(alphamap[r][c]['max']) == 'vmax':
                    amax[r][c] = vmax[r][c]
                elif str(alphamap[r][c]['max']) == 'vmax':
                    amax[r][c] = vmin[r][c]
                else:
                    amax[r][c] = np.nanpercentile(im,
                                                  float(alphamap[r][c]['max']))

                uim[uim < amin[r][c]] = amin[r][c]
                uim[uim > amax[r][c]] = amax[r][c]

                cmap2 = _create_alpha_colmap(alphamap[r][c]['mincolor'],
                                             alphamap[r][c]['maxcolor'],
                                             alphamap[r][c]['minalpha'],
                                             alphamap[r][c]['maxalpha'])

                if alphamap[r][c]['log']:
                    unorm = LogNorm()
    #                uim = np.log10(1000.0 * (uim - amin[r][c]) /
    #                          (amax[r][c] - amin[r][c]) + 1)

                else:
                    unorm = None

                ahandle = grid[i].imshow(uim,
                                         cmap=cmap2,
                                         origin='lower',
                                         interpolation='nearest',
                                         norm=unorm)

                ahandle.set_extent(extent)
                ahandles.append(ahandle)

            # --- optionally draw contours
            if contours is not None:

                # --- determine the levels
                if contours[r][c]['min'] is None:
                    cmin = np.nanmin(im)
                elif str(contours[r][c]['min']) == 'vmax':
                    cmin = vmax[r][c]
                elif str(contours[r][c]['min']) == 'vmin':
                    cmin = vmin[r][c]
                else:
                    cmin = contours[r][c]['min']

                if contours[r][c]['max'] is None:
                    cmax = np.nanmax(im)
                elif str(contours[r][c]['max']) == 'vmax':
                    cmax = vmax[r][c]
                elif str(contours[r][c]['max']) == 'vmax':
                    cmax = vmin[r][c]
                else:
                    cmax = contours[r][c]['max']

                if contours[r][c]['nstep'] is None:
                    nstep = 10
                else:
                    nstep = contours[r][c]['nstep']

                if contours[r][c]['stepsize'] is None:
                    stepsize = int(cmax - cmin / nstep)
                else:
                    stepsize = contours[r][c]['stepsize']

                levels = np.arange(cmin, cmax, stepsize)

                cont = grid[i].contour(im,
                                       levels=levels,
                                       origin='lower',
                                       colors=contours[r][c]['color'],
                                       linewidth=contours[r][c]['linewidth'],
                                       extent=extent)

                grid[i].clabel(cont,
                               levels[1::contours[r][c]['labelinterval']],
                               inline=1,
                               fontsize=contours[r][c]['labelsize'],
                               fmt=contours[r][c]['labelfmt'])

            if pfovs is not None and view_as is not None:
                xmin = 0.5 * view_as
                xmax = -0.5 * view_as
                ymin = -0.5 * view_as
                ymax = 0.5 * view_as

            elif view_px is not None:
                xmin = -0.5 * view_px
                xmax = 0.5 * view_px
                ymin = -0.5 * view_px
                ymax = 0.5 * view_px

            if verbose:
                print("xmin,xmax,ymin,ymax:", xmin, xmax, ymin, ymax)

            grid[i].set_ylim(ymin, ymax)
            grid[i].set_xlim(xmin, xmax)

            # --- optionally draw some lines
            if lines is not None:

                nlin = len(lines[r][c]["length"])
                for l in range(nlin):

                    # --- provided unit for lines is arcsec?
                    #                print(lines[r][c]["length"][l])
                    if lines[r][c]["unit"][l] == "px":
                        lines[r][c]["length"][l] *= pfovs[r][c]
                        lines[r][c]["xoff"][l] *= pfovs[r][c]
                        lines[r][c]["yoff"][l] *= pfovs[r][c]

    #                print(lines[r][c]["length"][l])

                    ang_rad = (90 - lines[r][c]["pa"][l]) * np.pi / 180.0
                    hlength = 0.5 * lines[r][c]["length"][l]

                    lx = [
                        lines[r][c]["xoff"][l] - hlength * np.cos(ang_rad),
                        lines[r][c]["xoff"][l] + hlength * np.cos(ang_rad)
                    ]

                    ly = [
                        lines[r][c]["yoff"][l] - hlength * np.sin(ang_rad),
                        lines[r][c]["yoff"][l] + hlength * np.sin(ang_rad)
                    ]

                    grid[i].plot(lx,
                                 ly,
                                 color=lines[r][c]["color"][l],
                                 linestyle=lines[r][c]["style"][l],
                                 marker='',
                                 linewidth=lines[r][c]["thick"][l],
                                 alpha=lines[r][c]["alpha"][l])

                    #                theta1 = lines[r][c]["pa"][l] - 10
                    #                theta2 = lines[r][c]["pa"][l] + 10

                    # --- angular error bars
                    if lines[r][c]["pa_err"][l] > 0:

                        theta1 = 90 - lines[r][c]["pa"][l] - lines[r][c][
                            "pa_err"][l]
                        theta2 = 90 - lines[r][c]["pa"][l] + lines[r][c][
                            "pa_err"][l]

                        parc = Arc(
                            (lines[r][c]["xoff"][l], lines[r][c]["yoff"][l]),
                            width=2 * hlength,
                            height=2 * hlength,
                            color=lines[r][c]["color"][l],
                            linewidth=lines[r][c]["thick"][l],
                            angle=0,
                            theta1=theta1,
                            theta2=theta2,
                            alpha=lines[r][c]["alpha"][l])

                        grid[i].add_patch(parc)

                        theta1 = -90 - lines[r][c]["pa"][l] - lines[r][c][
                            "pa_err"][l]
                        theta2 = -90 - lines[r][c]["pa"][l] + lines[r][c][
                            "pa_err"][l]

                        parc = Arc(
                            (lines[r][c]["xoff"][l], lines[r][c]["yoff"][l]),
                            width=2 * hlength,
                            height=2 * hlength,
                            color=lines[r][c]["color"][l],
                            linewidth=lines[r][c]["thick"][l],
                            angle=0,
                            theta1=theta1,
                            theta2=theta2,
                            alpha=lines[r][c]["alpha"][l])

                        grid[i].add_patch(parc)

            # --- optionally draw some symbols
            if plotsym is not None:
                #             if hasattr(plotsym[r][c]['x'], "__len__"):
                nsymplots = len(plotsym[r][c]['x'])
                print('nsymplots: ', nsymplots)
                for j in range(nsymplots):
                    #                grid[i].plot(plotsym[r][c]['x'][j], plotsym[r][c]['y'][j],
                    #                             marker=plotsym[r][c]['marker'][j],
                    #                             markersize=plotsym[r][c]['markersize'][j],
                    #                             markeredgewidth=plotsym[r][c]['markeredgewidth'][j],
                    #                             fillstyle=plotsym[r][c]['fillstyle'][j],
                    #                             alpha=plotsym[r][c]['alpha'][j],
                    #                             markerfacecolor=plotsym[r][c]['markerfacecolor'][j],
                    #                             markeredgecolor=plotsym[r][c]['markeredgecolor'][j])

                    if plotsym[r][c]['marker'][j] == 'ch':
                        marker = _crosshair_marker()
                    elif plotsym[r][c]['marker'][j] == 'ch45':
                        marker = _crosshair_marker(pa=45)
                    else:
                        marker = plotsym[r][c]['marker'][j]

                    grid[i].scatter(plotsym[r][c]['x'][j],
                                    plotsym[r][c]['y'][j],
                                    marker=marker,
                                    s=plotsym[r][c]['size'][j],
                                    linewidths=plotsym[r][c]['linewidth'][j],
                                    alpha=plotsym[r][c]['alpha'][j],
                                    facecolor=plotsym[r][c]['color'][j],
                                    edgecolors=plotsym[r][c]['edgecolor'][j],
                                    linestyle=plotsym[r][c]['linestyle'][j])

                    if plotsym[r][c]['label'][j] is not None:
                        grid[i].text(
                            plotsym[r][c]['x'][j][0] +
                            plotsym[r][c]['labelxoffset'][j],
                            plotsym[r][c]['y'][j][0] +
                            plotsym[r][c]['labelyoffset'][j],
                            plotsym[r][c]['label'][j],
                            color=plotsym[r][c]['labelcolor'][j],
                            verticalalignment=plotsym[r][c]['labelvalign'][j],
                            horizontalalignment=plotsym[r][c]['labelhalign']
                            [j])

            # --- ticks
            if majtickinterval is None:

                majorLocator = MultipleLocator(1)

                if ymax - ymin < 2:
                    majorLocator = MultipleLocator(0.5)
                    minorLocator = AutoMinorLocator(5)

                elif (ymax - ymin > 10) & (ymax - ymin <= 20):
                    majorLocator = MultipleLocator(2)

                elif (ymax - ymin > 20) & (ymax - ymin <= 100):
                    majorLocator = MultipleLocator(5)

                elif ymax - ymin > 100:
                    majorLocator = MultipleLocator(10)

            else:
                majorLocator = MultipleLocator(majtickinterval)

            minorLocator = AutoMinorLocator(10)
            if ymax - ymin < 2:
                minorLocator = AutoMinorLocator(5)

            grid[i].xaxis.set_minor_locator(minorLocator)
            grid[i].xaxis.set_major_locator(majorLocator)
            grid[i].yaxis.set_minor_locator(minorLocator)
            grid[i].yaxis.set_major_locator(majorLocator)
            grid[i].yaxis.set_tick_params(width=1.5, which='both')
            grid[i].xaxis.set_tick_params(width=1.5, which='both')
            grid[i].xaxis.set_tick_params(length=6)
            grid[i].yaxis.set_tick_params(length=6)
            grid[i].xaxis.set_tick_params(length=3, which='minor')
            grid[i].yaxis.set_tick_params(length=3, which='minor')

            # --- text
            if titles[r][c]:
                #   pdb.set_trace()

                if verbose:
                    print("titles[r][c]", titles[r][c])

                grid[i].set_title(titles[r][c],
                                  x=titx,
                                  y=tity,
                                  color=titcols[r][c],
                                  verticalalignment=titvalign,
                                  horizontalalignment=tithalign)

            if subtitles[r][c]:

                if verbose:
                    print("subtitles[r][c]", subtitles[r][c])

                grid[i].text(subtitx,
                             subtity,
                             subtitles[r][c],
                             color=subtitcols[r][c],
                             transform=grid[i].transAxes,
                             verticalalignment=subtitvalign,
                             horizontalalignment=subtithalign)

            if texts[r][c]:
                if verbose:
                    print("texts[r][c]", texts[r][c])
                grid[i].text(textx,
                             texty,
                             texts[r][c],
                             color=textcols[r][c],
                             transform=grid[i].transAxes,
                             verticalalignment=textvalign,
                             horizontalalignment=texthalign,
                             fontsize=textsize)

            # --- scale bar for size comparison
            if sbarlength[r][c]:
                if sbarunit == 'px':
                    sbarlength[r][c] = sbarlength[r][c] * pfovs[r][c]
                sx = [
                    sbarpos[0] - 0.5 * sbarlength[r][c],
                    sbarpos[0] + 0.5 * sbarlength[r][c]
                ]
                sy = [sbarpos[1], sbarpos[1]]
                grid[i].plot(sx,
                             sy,
                             linewidth=sbarthick,
                             color=sbarcol,
                             transform=grid[i].transAxes)

            handles.append(handle)

    # --- get the extent of the largest box containing all the axes/subplots
    extents = np.array([bla.get_position().extents for bla in grid])
    bigextents = np.empty(4)
    bigextents[:2] = extents[:, :2].min(axis=0)
    bigextents[2:] = extents[:, 2:].max(axis=0)

    # --- distance between the external axis and the text
    if xlabelpad is None and papercol == 2:
        xlabelpad = 0.03
    elif xlabelpad is None and papercol == 1:
        xlabelpad = 0.15
    elif xlabelpad is None:
        xlabelpad = 0.15 / papercol * 0.5

    if ylabelpad is None and papercol == 2:
        ylabelpad = 0.055
    elif ylabelpad is None and papercol == 1:
        ylabelpad = 0.1
    elif ylabelpad is None:
        ylabelpad = 0.1 / papercol

    if verbose:
        print("xlabelpad,ylabelpad: ", xlabelpad, ylabelpad)

    # --- text to mimic the x and y label. The text is positioned in
    #     the middle
    fig.text((bigextents[2] + bigextents[0]) / 2,
             bigextents[1] - xlabelpad,
             xlabel,
             horizontalalignment='center',
             verticalalignment='bottom')

    fig.text(bigextents[0] - ylabelpad, (bigextents[3] + bigextents[1]) / 2,
             ylabel,
             rotation='vertical',
             horizontalalignment='left',
             verticalalignment='center')

    # --- now colorbar business:
    if colorbar is not None:
        # first draw the figure, such that the axes are positionned
        fig.canvas.draw()
        #create new axes according to coordinates of image plot
        trans = fig.transFigure.inverted()

    if colorbar == 'row':
        for i in range(nrows):
            g = grid[i * ncols - 1].bbox.transformed(trans)

            if alphamap[ncols, i] is not None:

                height = 0.5 * g.height
                pos = [
                    g.x1 + g.width * 0.02, g.y0 + height, g.width * cbarwidth,
                    height
                ]
                cax = fig.add_axes(pos)

                if alphamap[ncols, i]['log']:
                    formatter = LogFormatter(10, labelOnlyBase=False)
                    cb = plt.colorbar(ticks=[
                        0.1, 0.2, 0.5, 1, 5, 10, 20, 50, 100, 200, 500, 1000,
                        2000, 5000, 10000
                    ],
                                      format=formatter,
                                      mappable=ahandles[i * ncols - 1],
                                      cax=cax)
                else:
                    cb = plt.colorbar(mappable=ahandles[i * ncols - 1],
                                      cax=cax)

                #majorLocator = MultipleLocator(alphamap[i*ncols - 1]['cbartickinterval'])
                #cb.ax.yaxis.set_major_locator(majorLocator)

#                if alphamap[i*ncols - 1]['log']:
#                    oldlabels = cb.ax.get_yticklabels()
#                    oldlabels = np.array([float(x.get_text().replace('$','')) for x in oldlabels])
#                    newlabels = 0.001*(10.0**oldlabels -1)*(amax[i*ncols - 1] - amin[i*ncols - 1]) + amin[i*ncols - 1]
#                    newlabels = [str(x)[:4] for x in newlabels]
#                    cb.ax.set_yticklabels(newlabels)

            else:

                height = g.height

            # --- pos = [left, bottom, width, height]
            pos = [g.x1 + g.width * 0.02, g.y0, g.width * cbarwidth, height]
            cax = fig.add_axes(pos)

            if (vmax[ncols, i] - vmin[ncols, i]) < 10:
                decimals = 1
            else:
                decimals = 0
            if (vmax[ncols, i] - vmin[ncols, i]) < 1:
                decimals = 2

            ticks = np.arange(vmin[ncols, i] + 1.0 / 10**decimals,
                              vmax[ncols, i], cbartickinterval)

            ticks = np.round(ticks, decimals=decimals)

            cb = plt.colorbar(ticks=ticks,
                              mappable=handles[i * ncols - 1],
                              cax=cax)

            # majorLocator = MultipleLocator(cbartickinterval)
            # cb.ax.yaxis.set_major_locator(majorLocator)

            if scale[ncols, i] == "log":
                oldlabels = cb.ax.get_yticklabels()
                oldlabels = np.array(
                    [float(x.get_text().replace('$', '')) for x in oldlabels])
                newlabels = 0.001 * (10.0**oldlabels - 1) * (
                    vmax[i * ncols - 1] - vmin[ncols, i]) + vmin[ncols, i]
                newlabels = [str(x)[:4] for x in newlabels]
                cb.ax.set_yticklabels(newlabels)

    elif colorbar == 'single':

        gb = grid[-1].bbox.transformed(trans)
        gt = grid[0].bbox.transformed(trans)

        if alphamap is not None:
            height = 0.5 * (gt.y1 - gb.y0)
            pos = [
                gb.x1 + gb.width * 0.02, gb.y0 + height, gb.width * cbarwidth,
                height
            ]
            cax = fig.add_axes(pos)
            # cb = plt.colorbar(mappable=ahandles[0], cax=cax)

            #               majorLocator = MultipleLocator(alphamap[0]['cbartickinterval'])
            #               cb.ax.yaxis.set_major_locator(majorLocator)

            #                if alphamap[0]['log']:
            #                   oldlabels = cb.ax.get_yticklabels()
            #                   oldlabels = np.array([float(x.get_text().replace('$','')) for x in oldlabels])
            #                   newlabels = 0.001*(10.0**oldlabels -1)*(amax[0] - amin[0]) + amin[0]
            #                   newlabels = [str(x)[:4] for x in newlabels]
            #                   cb.ax.set_yticklabels(newlabels)
            if alphamap[0]['log']:
                formatter = LogFormatter(10, labelOnlyBase=False)
                cb = plt.colorbar(ticks=[
                    0.1, 0.2, 0.5, 1, 5, 10, 20, 50, 100, 200, 500, 1000, 2000,
                    5000, 10000
                ],
                                  format=formatter,
                                  mappable=ahandles[0],
                                  cax=cax)
            else:
                cb = plt.colorbar(mappable=ahandles[0], cax=cax)

        else:
            height = (gt.y1 - gb.y0)

        # --- pos = [left, bottom, width, height]
        pos = [gb.x1 + gb.width * 0.02, gb.y0, gb.width * cbarwidth, height]
        cax = fig.add_axes(pos)
        cb = plt.colorbar(mappable=handles[0], cax=cax)

        # majorLocator = MultipleLocator(cbartickinterval)
        # cb.ax.yaxis.set_major_locator(majorLocator)

        if scale[0, 0] == "log":
            oldlabels = cb.ax.get_yticklabels()
            oldlabels = np.array(
                [float(x.get_text().replace('$', '')) for x in oldlabels])
            newlabels = 0.001 * (10.0**oldlabels - 1) * (vmax[0] -
                                                         vmin[0]) + vmin[0]
            newlabels = [str(x)[:4] for x in newlabels]
            cb.ax.set_yticklabels(newlabels)

    if cbarlabel is not None:

        if cbarlabelpad is None:
            if papercol == 1:
                cbarlabelpad = 0.15
            else:
                cbarlabelpad = 0.08

        fig.text(bigextents[2] + cbarlabelpad,
                 (bigextents[3] + bigextents[1]) / 2,
                 cbarlabel,
                 rotation='vertical',
                 horizontalalignment='right',
                 verticalalignment='center')

    if outname:
        plt.savefig(outname, bbox_inches='tight', pad_inches=0.1)
        plt.close(fig)

    else:
        plt.show()

    if latex:
        mpl.rcdefaults()
Beispiel #9
0
# Report how many training images each category directory contains.
for category in CATEGORIES:
    n_images = len(os.listdir(os.path.join(train_dir, category)))
    print('{}: {} images'.format(category, n_images))
print("-" * 27)

# Build a DataFrame of the training set: one row per image with its relative
# file path, category name and integer category id (position in CATEGORIES).
# NOTE(review): the stored paths use a hard-coded '../data/train/' prefix while
# the listing uses `train_dir`; presumably train_dir == '../data/train' --
# confirm, otherwise the stored paths will not point at the listed files.
train = []
for category_id, category in enumerate(CATEGORIES):
    for fname in os.listdir(os.path.join(train_dir, category)):
        train.append(['../data/train/{}/{}'.format(category, fname), category, category_id])

train = pd.DataFrame(train, columns=['file', 'category', 'category_id'])


# Show sample images: one grid row per category, up to NumCatergories images
# per row, with the category name written to the right of the last column.
fig = plt.figure(1, figsize=(NumCatergories, NumCatergories))
grid = ImageGrid(fig, 111, nrows_ncols=(NumCatergories, NumCatergories), axes_pad=0.05)

for category_id, category in enumerate(CATEGORIES):
    samples = train.loc[train['category'] == category, 'file'].values[:NumCatergories]
    for col, filepath in enumerate(samples):
        # Address the cell by (row, column) instead of a running counter, so a
        # category with fewer than NumCatergories images does not shift every
        # following category into the wrong row (and mislabel the rows).
        i = category_id * NumCatergories + col
        ax = grid[i]
        img = read_img(filepath)
        ax.imshow(img)
        ax.axis('off')
        if col == NumCatergories - 1:
            # Label the row with its category; the path's 4th component
            # parsed previously is exactly this category name.
            ax.text(250, 112, category, verticalalignment='center')

print("display samples")
plt.pause(3)
plt.savefig('../results/sample_images.png', transparent=True)
Beispiel #10
0
    def test_calc_likelihoods_2():
        from numpy.testing import assert_array_almost_equal
        from numpy.testing import assert_almost_equal
        from myimage_analysis import calculate_nhi
        import matplotlib.pyplot as plt
        from matplotlib import cm
        from mpl_toolkits.axes_grid1 import ImageGrid

        av_image = np.array([[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 2, 1, 0],
                             [np.nan, 1, 1, 1, 0], [0, 0, 0, 0, 0]])

        #av_image_error = np.random.normal(0.1, size=av_image.shape)
        av_image_error = 0.1 * np.ones(av_image.shape)

        #nhi_image = av_image + np.random.normal(0.1, size=av_image.shape)
        hi_cube = np.zeros((5, 5, 5))

        # make inner channels correlated with av
        hi_cube[:, :, :] = np.array([
            [
                [1., 0., 0., 0., 0.],
                [np.nan, 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [1., 0., 0., 0., 10.],
            ],
            [
                [0., 0., 0., 0., 0.],
                [0., 0., 2., 0., 0.],
                [0., 0., 4., 0., 0.],
                [0., 0., 2., 0., 0.],
                [0., 0., 0., 0., 0.],
            ],
            [
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 2., 0.],
                [0., 0., 0., 2., 0.],
                [0., 0., 0., 2., np.nan],
                [0., 0., 0., 0., 0.],
            ],
            [
                [0., 0., 0., 0., 0.],
                [0., 2., 0., 0., 0.],
                [0., 2., 0., 0., 0.],
                [0., 2., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
            ],
            [
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., np.nan],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [1., 0., 0., 0., 0.2],
            ],
        ])

        if 1:
            fig = plt.figure(figsize=(4, 4))
            imagegrid = ImageGrid(fig, (1, 1, 1),
                                  nrows_ncols=(1, 5),
                                  ngrids=5,
                                  cbar_mode="single",
                                  cbar_location='top',
                                  cbar_pad="2%",
                                  cbar_size='3%',
                                  axes_pad=0.1,
                                  aspect=True,
                                  label_mode='L',
                                  share_all=True)
            cmap = cm.get_cmap('Greys', 5)
            for i in xrange(5):
                im = imagegrid[i].imshow(
                    hi_cube[i, :, :],
                    origin='lower',
                    #aspect='auto',
                    cmap=cmap,
                    interpolation='none',
                    vmin=0,
                    vmax=4)
            #cb = imagegrid[i].cax.colorbar(im)
            cbar = imagegrid.cbar_axes[0].colorbar(im)
            #plt.title('HI Cube')
            plt.savefig('/usr/users/ezbc/Desktop/hi_cube.png')

        # make edge channels totally uncorrelated
        #hi_cube[(0, 4), :, :] = np.arange(0, 25).reshape(5,5)
        #hi_cube[(0, 4), :, :] = - np.ones((5,5))

        hi_vel_axis = np.arange(0, 5, 1)

        # add intercept
        intercept_answer = 0.9
        av_image = av_image + intercept_answer

        if 1:
            fig = plt.figure(figsize=(4, 4))
            params = {
                'figure.figsize': (1, 1),
                #'figure.titlesize': font_scale,
            }
            plt.rcParams.update(params)
            imagegrid = ImageGrid(fig, (1, 1, 1),
                                  nrows_ncols=(1, 1),
                                  ngrids=1,
                                  cbar_mode="single",
                                  cbar_location='top',
                                  cbar_pad="2%",
                                  cbar_size='3%',
                                  axes_pad=0.1,
                                  aspect=True,
                                  label_mode='L',
                                  share_all=True)
            cmap = cm.get_cmap('Greys', 5)
            im = imagegrid[0].imshow(
                av_image,
                origin='lower',
                #aspect='auto',
                cmap=cmap,
                interpolation='none',
                vmin=0,
                vmax=4)
            #cb = imagegrid[i].cax.colorbar(im)
            cbar = imagegrid.cbar_axes[0].colorbar(im)
            #plt.title('HI Cube')
            plt.savefig('/usr/users/ezbc/Desktop/av.png')

        width_grid = np.arange(0, 5, 1)
        dgr_grid = np.arange(0, 1, 0.1)
        intercept_grid = np.arange(-1, 1, 0.1)
        vel_center = 2

        results = \
            cloudpy._calc_likelihoods(
                              hi_cube=hi_cube / 1.832e-2,
                              hi_vel_axis=hi_vel_axis,
                              vel_center=vel_center,
                              av_image=av_image,
                              av_image_error=av_image_error,
                              width_grid=width_grid,
                              dgr_grid=dgr_grid,
                              intercept_grid=intercept_grid,
                              )

        dgr_answer = 1 / 2.0
        width_answer = 2
        width = results['width_max']
        dgr = results['dgr_max']
        intercept = results['intercept_max']
        print width

        if 0:
            width = width_answer
            intercept = intercept_answer
            dgr = dgr_answer

        vel_range = (vel_center - width / 2.0, vel_center + width / 2.0)

        nhi_image = calculate_nhi(cube=hi_cube,
                                  velocity_axis=hi_vel_axis,
                                  velocity_range=vel_range) / 1.823e-2
        if 1:
            fig = plt.figure(figsize=(4, 4))
            imagegrid = ImageGrid(fig, (1, 1, 1),
                                  nrows_ncols=(1, 1),
                                  ngrids=1,
                                  cbar_mode="single",
                                  cbar_location='top',
                                  cbar_pad="2%",
                                  cbar_size='3%',
                                  axes_pad=0.1,
                                  aspect=True,
                                  label_mode='L',
                                  share_all=True)
            cmap = cm.get_cmap('Greys', 5)
            im = imagegrid[0].imshow(
                nhi_image,
                origin='lower',
                #aspect='auto',
                cmap=cmap,
                interpolation='none',
                vmin=0,
                vmax=4)
            #cb = imagegrid[i].cax.colorbar(im)
            cbar = imagegrid.cbar_axes[0].colorbar(im)
            #plt.title('HI Cube')
            plt.savefig('/usr/users/ezbc/Desktop/nhi.png')
        if 1:
            fig = plt.figure(figsize=(4, 4))
            imagegrid = ImageGrid(fig, (1, 1, 1),
                                  nrows_ncols=(1, 1),
                                  ngrids=1,
                                  cbar_mode="single",
                                  cbar_location='top',
                                  cbar_pad="2%",
                                  cbar_size='3%',
                                  axes_pad=0.1,
                                  aspect=True,
                                  label_mode='L',
                                  share_all=True)
            cmap = cm.get_cmap('Greys', 5)
            im = imagegrid[0].imshow(
                nhi_image * dgr + intercept,
                origin='lower',
                #aspect='auto',
                cmap=cmap,
                interpolation='none',
                vmin=0,
                vmax=4)
            #cb = imagegrid[i].cax.colorbar(im)
            cbar = imagegrid.cbar_axes[0].colorbar(im)
            #plt.title('HI Cube')
            plt.savefig('/usr/users/ezbc/Desktop/av_model.png')

        print('residuals = ')
        print(av_image - (nhi_image * dgr + intercept))
        print('dgr', dgr)
        print('intercept', intercept)
        print('width', width)

        assert_almost_equal(results['intercept_max'], intercept_answer)
        assert_almost_equal(results['dgr_max'], dgr_answer)
        assert_almost_equal(results['width_max'], width_answer)
Beispiel #11
0
def plot_tsne_selection_grid(z_pos,
                             x_pos,
                             z_neg,
                             vmin,
                             vmax,
                             fig_path,
                             labels=None,
                             fig_size=(9, 9),
                             g_j=7,
                             s=.5,
                             suffix='png'):
    """Plot one embedding scatter panel per column of ``x_pos`` on an ImageGrid.

    Each panel draws the ``z_neg`` points in light gray. It then overlays the
    ``z_pos`` points, coloured by the corresponding column of ``x_pos``, and
    adds a per-panel colour bar on top. The figure is saved to
    ``fig_path + '.' + suffix`` and then cleared and closed.

    :param z_pos: coordinates of the coloured points; columns 0 and 1 are
        used as x and y.
    :param x_pos: values used to colour ``z_pos``; one panel per column.
    :param z_neg: coordinates of the gray background points (columns 0/1).
    :param vmin: per-panel lower colour limits, indexed by column of ``x_pos``.
    :param vmax: per-panel upper colour limits, indexed by column of ``x_pos``.
    :param fig_path: output path without extension.
    :param labels: per-panel labels; defaults to the column indices as strings.
    :param fig_size: matplotlib figure size in inches.
    :param g_j: number of panels per grid row.
    :param s: marker size for both scatter layers.
    :param suffix: file extension, also passed as ``format`` to ``savefig``.
    """
    ncol = x_pos.shape[1]
    # Ceiling division: number of grid rows needed for ``ncol`` panels when
    # each row holds ``g_j`` panels.
    g_i = ncol // g_j if (ncol % g_j == 0) else ncol // g_j + 1
    if labels is None:
        # Default labels are the column indices "0", "1", ...
        # (``np.range`` does not exist; the builtin ``range`` is used).
        labels = [str(a) for a in range(ncol)]

    fig = plt.figure(figsize=fig_size)
    fig.clf()
    # ``add_all=True`` is omitted: it was ImageGrid's default and the
    # parameter was removed in Matplotlib 3.5.
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(g_i, g_j),
        ngrids=ncol,
        aspect=True,
        direction="row",
        axes_pad=(0.15, 0.5),
        label_mode="1",
        share_all=True,
        cbar_location="top",
        cbar_mode="each",
        cbar_size="8%",
        cbar_pad="5%",
    )
    for seq_index in range(ncol):
        ax = grid[seq_index]
        # Panel label at (0, 0.92) in axes coordinates.
        ax.text(0,
                .92,
                labels[seq_index],
                horizontalalignment='center',
                transform=ax.transAxes,
                size=20,
                weight='bold')
        a = x_pos[:, seq_index]
        # Background layer: all "negative" points in light gray.
        ax.scatter(z_neg[:, 0],
                   z_neg[:, 1],
                   s=s,
                   marker='o',
                   c='lightgray',
                   alpha=0.5,
                   edgecolors='face')
        # Foreground layer: "positive" points coloured by this column.
        im = ax.scatter(z_pos[:, 0],
                        z_pos[:, 1],
                        s=s,
                        marker='o',
                        c=a,
                        cmap=cm.jet,
                        edgecolors='face',
                        vmin=vmin[seq_index],
                        vmax=vmax[seq_index])
        ax.cax.colorbar(im)
        clean_axis(ax)
        ax.grid(False)
    plt.savefig('.'.join([fig_path, suffix]), format=suffix)
    plt.clf()
    plt.close()
Beispiel #12
0
def plot_mass2light(
    ages=None,
    m2l=None,
    limits=None,
    savedir='./',
    filename=None,
    show=True,
    title='',
):
    ''' Plots mass-to-light ratio against age on a single axes.

    Parameters
    ----------
    ages : sequence
        x values (axis labelled "Age (Gyr)"). Its length also sizes the
        colour cycle written to rcParams, so it must not be None.
    m2l : sequence
        y values (mass-to-light ratios), same length as ``ages``.
    limits : sequence of 4 numbers, optional
        (xmin, xmax, ymin, ymax) axis limits.
    savedir : str
        Prefix string-concatenated with ``filename``.
    filename : str, optional
        If given, the figure is saved to ``savedir + filename`` at 600 dpi.
    show : bool
        Call ``fig.show()`` at the end.
    title : str
        Axes title.
    '''

    # Import external modules
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    colormap = plt.cm.gist_ncar
    color_cycle = [colormap(i) for i in np.linspace(0, 0.9, len(ages))]
    fontScale = 12
    params = {  #'backend': .pdf',
        'axes.labelsize': fontScale,
        'axes.titlesize': fontScale,
        # 'font.size' replaces the deprecated alias 'text.fontsize', which
        # current matplotlib rejects with a KeyError.
        'font.size': fontScale,
        'legend.fontsize': fontScale * 3 / 4,
        'xtick.labelsize': fontScale,
        'ytick.labelsize': fontScale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (6, 6),
    }
    # Colours of different plots: 'axes.color_cycle' was superseded by
    # 'axes.prop_cycle' in matplotlib 1.5 and later removed.
    if 'axes.prop_cycle' in plt.rcParams:
        params['axes.prop_cycle'] = matplotlib.cycler(color=color_cycle)
    else:
        params['axes.color_cycle'] = color_cycle
    plt.rcParams.update(params)

    # Create figure: a 1x1 ImageGrid used as a plain axes. 'all' is the
    # recognised spelling of label_mode ('All' is not a valid value).
    fig = plt.figure()
    grid = ImageGrid(fig, (1, 1, 1),
                     nrows_ncols=(1, 1),
                     ngrids=1,
                     direction='row',
                     axes_pad=1,
                     aspect=False,
                     share_all=True,
                     label_mode='all')
    ax = grid[0]

    ax.plot(ages,
            m2l,
            marker='s',
            color='k',
            markersize=3)
    # Horizontal reference line at y = 4.83 / 4.64, annotated as M_sun / L_sun.
    ax.axhline(y=4.83 / 4.64, xmin=-1, xmax=100, color='k')
    ax.annotate(r'$M_\odot / L_\odot$',
                xy=(10, 4.83 / 4.64),
                textcoords='offset points',
                xytext=(2, 3))

    if limits is not None:
        ax.set_xlim(limits[0], limits[1])
        ax.set_ylim(limits[2], limits[3])

    # Adjust aesthetics
    ax.set_xlabel(r'Age (Gyr)', )
    ax.set_ylabel(r'$M / L (M_\odot / L_\odot$)')
    ax.grid(True)
    ax.legend(loc='upper right')
    ax.set_title(title)

    if filename is not None:
        plt.savefig(savedir + filename, bbox_inches='tight', dpi=600)
    if show:
        fig.show()
Beispiel #13
0
def plot_fluxes(wavelengths=None,
                flux_list=None,
                ages=None,
                metals=None,
                limits=None,
                savedir='./',
                filename=None,
                show=True,
                title='',
                log_scale=(1, 1),
                normalized=True,
                attenuations=None,
                age_unit='Gyr',
                balmer_line=False):
    ''' Plots one or more spectra (flux vs. wavelength) on a single axes.

    Parameters
    ----------
    wavelengths : sequence
        Shared x values for every spectrum.
    flux_list : sequence of sequences
        One flux array per spectrum. Its length also sizes the colour cycle.
    ages, metals, attenuations : sequence, optional
        Per-spectrum values used to build legend labels. ``ages`` and
        ``metals`` take precedence over ``attenuations``. When none is given,
        the spectra are drawn without labels.
    limits : sequence of 4 numbers, optional
        (xmin, xmax, ymin, ymax) axis limits.
    savedir, filename : str
        If ``filename`` is given, the figure is saved to
        ``savedir + filename`` at 600 dpi.
    show : bool
        Call ``fig.show()`` at the end.
    title : str
        Axes title.
    log_scale : (bool, bool)
        Use a log x axis / log y axis respectively.
    normalized : bool
        Selects the y-axis label (normalized to 5500 Angstrom or not).
    age_unit : str
        Unit string inserted into age labels.
    balmer_line : bool
        Draw vertical lines at a fixed list of Balmer-series wavelengths.
    '''

    # Import external modules
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    colormap = plt.cm.gist_ncar
    color_cycle = [colormap(i) for i in np.linspace(0, 0.9, len(flux_list))]
    fontScale = 12
    params = {  #'backend': .pdf',
        'axes.labelsize': fontScale,
        'axes.titlesize': fontScale,
        # 'font.size' replaces the deprecated alias 'text.fontsize', which
        # current matplotlib rejects with a KeyError.
        'font.size': fontScale,
        'legend.fontsize': fontScale * 3 / 4,
        'xtick.labelsize': fontScale,
        'ytick.labelsize': fontScale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (6, 6),
    }
    # Colours of different plots: 'axes.color_cycle' was superseded by
    # 'axes.prop_cycle' in matplotlib 1.5 and later removed.
    if 'axes.prop_cycle' in plt.rcParams:
        params['axes.prop_cycle'] = matplotlib.cycler(color=color_cycle)
    else:
        params['axes.color_cycle'] = color_cycle
    plt.rcParams.update(params)

    # Create figure: a 1x1 ImageGrid used as a plain axes. 'all' is the
    # recognised spelling of label_mode ('All' is not a valid value).
    fig = plt.figure()
    grid = ImageGrid(fig, (1, 1, 1),
                     nrows_ncols=(1, 1),
                     ngrids=1,
                     direction='row',
                     axes_pad=1,
                     aspect=False,
                     share_all=True,
                     label_mode='all')
    ax = grid[0]

    if balmer_line:
        # Balmer-series wavelengths, in the same units as the x axis.
        lines = [6560, 4861, 4341, 4102, 3970, 3889, 3835, 3646]
        for line in lines:
            ax.axvline(x=line, ymin=0, ymax=1e10, color='k', alpha=0.5)

    for j, fluxes in enumerate(flux_list):
        # Build a legend label from whichever per-spectrum metadata is given.
        if ages is not None and metals is None:
            label = 'Age = %.1f %s' % (ages[j], age_unit)
        elif metals is not None and ages is None:
            label = r'Z = %s ' % metals[j]
        elif ages is not None and metals is not None:
            label = 'Age = %.1f %s, Z = %s ' % \
                    (ages[j], age_unit, metals[j])
        elif attenuations is not None:
            label = r'$A_V = $ %s' % attenuations[j]
        else:
            label = None

        if label is None:
            # No metadata: still draw the spectrum rather than skipping it.
            ax.plot(wavelengths, fluxes)
        else:
            ax.plot(wavelengths, fluxes, label=label)

    if limits is not None:
        ax.set_xlim(limits[0], limits[1])
        ax.set_ylim(limits[2], limits[3])

    # Adjust aesthetics
    if log_scale[0]:
        ax.set_xscale('log')
    if log_scale[1]:
        ax.set_yscale('log')
    ax.set_xlabel(r'$\lambda (\AA$)', )
    if normalized:
        ax.set_ylabel(r'$f_\lambda / f_{5500 \AA}$')
    else:
        ax.set_ylabel(r'$f_\lambda d\lambda$')

    ax.grid(True)
    ax.legend(loc='upper right')
    ax.set_title(title)

    if filename is not None:
        plt.savefig(savedir + filename, bbox_inches='tight', dpi=600)
    if show:
        fig.show()
Beispiel #14
0
def plot_mags(
    wavelengths=None,
    mag_list=None,
    ages=None,
    limits=None,
    savedir='./',
    filename=None,
    show=True,
    title='',
):
    ''' Plots AB magnitude spectra against wavelength with filter markers.

    Parameters
    ----------
    wavelengths : sequence
        Shared x values for every curve (log-scaled x axis).
    mag_list : sequence of sequences
        One magnitude array per curve. Its length also sizes the colour cycle.
    ages : sequence, optional
        Per-curve ages used for legend labels ("Age = ... Gyr"). When None,
        the curves are drawn without labels.
    limits : sequence of 4 numbers, optional
        (xmin, xmax, ymin, ymax) axis limits.
    savedir, filename : str
        If ``filename`` is given, the figure is saved to
        ``savedir + filename`` at 600 dpi.
    show : bool
        Call ``fig.show()`` at the end.
    title : str
        Axes title.
    '''

    # Import external modules
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    colormap = plt.cm.gist_ncar
    color_cycle = [colormap(i) for i in np.linspace(0, 0.9, len(mag_list))]
    fontScale = 12
    params = {  #'backend': .pdf',
        'axes.labelsize': fontScale,
        'axes.titlesize': fontScale,
        # 'font.size' replaces the deprecated alias 'text.fontsize', which
        # current matplotlib rejects with a KeyError.
        'font.size': fontScale,
        'legend.fontsize': fontScale * 3 / 4,
        'xtick.labelsize': fontScale,
        'ytick.labelsize': fontScale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (6, 6),
    }
    # Colours of different plots: 'axes.color_cycle' was superseded by
    # 'axes.prop_cycle' in matplotlib 1.5 and later removed.
    if 'axes.prop_cycle' in plt.rcParams:
        params['axes.prop_cycle'] = matplotlib.cycler(color=color_cycle)
    else:
        params['axes.color_cycle'] = color_cycle
    plt.rcParams.update(params)

    # Create figure: a 1x1 ImageGrid used as a plain axes. 'all' is the
    # recognised spelling of label_mode ('All' is not a valid value).
    fig = plt.figure()
    grid = ImageGrid(fig, (1, 1, 1),
                     nrows_ncols=(1, 1),
                     ngrids=1,
                     direction='row',
                     axes_pad=1,
                     aspect=False,
                     share_all=True,
                     label_mode='all')
    ax = grid[0]

    for j, mags in enumerate(mag_list):
        if ages is None:
            # No ages supplied: plot unlabelled instead of failing on ages[j].
            ax.plot(wavelengths, mags)
        else:
            ax.plot(
                wavelengths,
                mags,
                label='Age = %s Gyr' % ages[j],
            )

    # Filter names and their centre wavelengths (same units as the x axis);
    # each gets a vertical line and a text label near y = -15.
    filters = ['U', 'B', 'V', 'R', 'I', 'J', 'H', 'K', 'NUV', 'FUV']
    filter_centers = [
        3630,
        4450,
        5510,
        6580,
        8060,
        12200,
        16300,
        21900,
        2274,
        1542,
    ]
    for name, center in zip(filters, filter_centers):
        ax.axvline(x=center, ymin=0, ymax=1e10, color='k')
        ax.annotate(name,
                    xy=(center, -15),
                    textcoords='offset points',
                    xytext=(2, 3))

    if limits is not None:
        ax.set_xlim(limits[0], limits[1])
        ax.set_ylim(limits[2], limits[3])

    # Adjust aesthetics
    ax.set_xscale('log')
    ax.set_xlabel(r'$\lambda$ ($\AA$)', )
    ax.set_ylabel(r'$M_{AB}\ d\lambda$ (mag / $\AA$)')
    ax.legend(loc='upper right')
    ax.set_title(title)

    if filename is not None:
        plt.savefig(savedir + filename, bbox_inches='tight', dpi=600)
    if show:
        fig.show()
Beispiel #15
0
    col = 0
    row = 0

    while row < data.shape[1]:
        print('yielding (%d, %d)' % (col, row))
        yield (data[col, row])

        if col == 0:
            col = 1
        else:
            col = 0
            row += 1


# <codecell>
# Generator of images to display. Presumably this is the ``data_gen`` whose
# body appears above, which yields data[0, row], data[1, row] for each row
# (TODO confirm).
samp_im = data_gen()

# One figure holding a 5x2 grid of image axes (10 panels).
fig = plt.figure(figsize=(10, 10))
grid = ImageGrid(
    fig,
    111,  # similar to subplot(111)
    nrows_ncols=(5, 2),
    axes_pad=0.1,
)

# zip stops at whichever runs out first, the 10 grid axes or the generator;
# tqdm only reports progress (total=10 matches the panel count).
for ax, im in tqdm(zip(grid, samp_im), total=10):
    ax.imshow(im)

fig.suptitle('Sampled images and their nearest neighbor')
# plt.show()
plt.savefig('nn_fig.png')
            urcrnrlat=80,
            resolution='l',
            projection='cyl')

# Have a consistent longitude definition: add 360 to every negative
# longitude so all values are non-negative (assumes inputs are within
# [-180, 360) -- confirm).
lonlon[lonlon < 0] = lonlon[lonlon < 0] + 360.

# Convert lon/lat to map coordinates. ``m`` is presumably the map
# projection object constructed just above (projection='cyl').
x, y = m(lonlon[:, :], latlat[:, :])

# 2x2 grid of axes (presumably one per season below) sharing a single
# colour bar along the bottom.
fig = plt.figure(figsize=(10, 10))
grid = ImageGrid(
    fig,
    111,  # as in plt.subplot(111)
    nrows_ncols=(2, 2),
    axes_pad=0.15,
    share_all=True,
    cbar_location="bottom",
    cbar_mode="single",
    cbar_size="3%",
    cbar_pad=0.15,
    label_mode="L",
)

# Three-month season codes (DJF = Dec-Jan-Feb, etc.).
seasons = ['DJF', 'MAM', 'JJA', 'SON']
nseas = len(seasons)

# Tick positions in degrees: longitude 0..180 and latitude -60..60, every 30.
lonticks = np.arange(7) * 30.
latticks = np.arange(5) * 30. - 60

# Panel counter for the loop that follows.
i = 0
for ax in grid:
def plot_power_spectrum(image,
                        title=None,
                        filename_prefix=None,
                        filename_suffix='.png',
                        show=False,
                        savedir='./'):
    '''
    Plots power spectrum derived from a fourier transform of the image.

    Two figures are produced:
    - the azimuthally averaged (1D) power spectrum of ``image``, overlaid on
      the spectrum of a simulated white-noise image of the same shape;
    - the 2D power spectrum itself.

    Parameters
    ----------
    image : 2D array
        Input image. NaNs are replaced by 1e10 in a private copy before the
        FFT; the caller's array is not modified.
    title : str, optional
        Figure suptitle used for both figures.
    filename_prefix : str, optional
        If given, the figures are saved to
        ``savedir + filename_prefix + '_1D' + filename_suffix`` and
        ``savedir + filename_prefix + '_2D' + filename_suffix``.
    filename_suffix : str
        File extension, including the dot.
    show : bool
        Call ``plt.show()`` after each figure.
    savedir : str
        Directory prefix (string-concatenated).
    '''

    # import external modules
    import numpy as np
    from agpy import azimuthalAverage as radial_average
    from scipy import fftpack
    import matplotlib.pyplot as plt
    import matplotlib
    from mpl_toolkits.axes_grid1 import ImageGrid
    from matplotlib import cm

    def _psd2d(array):
        """Return |FFT|**2 of ``array`` with the zero frequency centred."""
        # fftshift moves low spatial frequencies to the centre of the image.
        return np.abs(fftpack.fftshift(fftpack.fft2(array)))**2

    def _single_axes(fig):
        """Return the only axes of a 1x1 ImageGrid filling ``fig``."""
        imagegrid = ImageGrid(
            fig,
            (1, 1, 1),
            nrows_ncols=(1, 1),
            ngrids=1,
            axes_pad=0.25,
            aspect=False,
            label_mode='L',
            share_all=True,
            cbar_pad=0.1,
            cbar_size=0.2,
        )
        return imagegrid[0]

    # Work on a copy: filling NaNs in place would modify the caller's array.
    # NOTE(review): NaNs are replaced by the sentinel 1e10 before the FFT;
    # confirm this is intended.
    image = np.array(image, copy=True)
    image[np.isnan(image)] = 1e10

    # Determine power spectrum
    # -------------------------------------------------------------------------
    psd2D = _psd2d(image)
    power_spectrum = radial_average(psd2D, interpnan=True)

    # Write frequency in arcmin. NOTE(review): the factor 5.0 presumably
    # converts cycles/pixel to 1/arcmin for this pixel scale -- confirm.
    freq = fftpack.fftfreq(len(power_spectrum))
    freq *= 5.0

    # Simulate power spectrum for white noise of the same shape
    noise_image = np.random.normal(scale=0.1, size=image.shape)
    psd2D_noise = _psd2d(noise_image)
    power_spectrum_noise = radial_average(psd2D_noise, interpnan=True)

    # Plot power spectrum 1D
    # -------------------------------------------------------------------------
    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    font_scale = 12
    params = {  #'backend': .pdf',
        'axes.labelsize': font_scale,
        'axes.titlesize': font_scale,
        # 'font.size' replaces the deprecated alias 'text.fontsize', which
        # current matplotlib rejects with a KeyError.
        'font.size': font_scale,
        'legend.fontsize': font_scale * 3 / 4.0,
        'xtick.labelsize': font_scale,
        'ytick.labelsize': font_scale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (5, 5),
    }
    plt.rcParams.update(params)

    fig = plt.figure()
    ax = _single_axes(fig)

    # Both curves are normalised to their own maximum.
    ax.plot(freq,
            power_spectrum / np.nanmax(power_spectrum),
            color='k',
            linestyle='-',
            linewidth=1.5,
            drawstyle='steps-mid',
            label='Data Residuals')

    ax.plot(freq,
            power_spectrum_noise / np.nanmax(power_spectrum_noise),
            color='r',
            linestyle='-',
            linewidth=0.4,
            drawstyle='steps-mid',
            label='White Noise Residuals')

    ax.legend(loc='best')
    ax.set_yscale('log')
    ax.set_xlabel('Spatial Frequency [1/arcmin]')
    ax.set_ylabel('Normalized Power Spectrum')
    ax.set_xlim(0, 0.4)

    if title is not None:
        fig.suptitle(title, fontsize=font_scale)
    if filename_prefix is not None:
        plt.savefig(savedir + filename_prefix + '_1D' + filename_suffix,
                    bbox_inches='tight')
    if show:
        plt.show()

    # Plot power spectrum image
    # -------------------------
    fig = plt.figure()
    ax = _single_axes(fig)

    # imshow's extent is (left, right, bottom, top): x spans the columns
    # (shape[1]) and y spans the rows (shape[0]).
    extent = [
        -image.shape[1] / 2.0, +image.shape[1] / 2.0, -image.shape[0] / 2.0,
        +image.shape[0] / 2.0
    ]

    ax.imshow(psd2D,
              origin='lower',
              cmap=cm.gist_heat,
              norm=matplotlib.colors.LogNorm(),
              extent=extent)

    ax.set_xlabel('Spatial Frequency in Right Ascension')
    ax.set_ylabel('Spatial Frequency in Declination')
    ax.set_xlim(-20, 20)
    ax.set_ylim(-20, 20)

    if title is not None:
        fig.suptitle(title, fontsize=font_scale)
    if filename_prefix is not None:
        plt.savefig(savedir + filename_prefix + '_2D' + filename_suffix,
                    bbox_inches='tight')
    if show:
        plt.show()
    # Mask with a disc
    R = R* disc((retina_shape[0],retina_shape[0]),
                (retina_shape[0]//2,retina_shape[0]//2),
                retina_shape[0]//2)

    # Take half-retina
    R = R[:,retina_shape[1]:]

    # Project to colliculus
    SC = R[P[...,0], P[...,1]]


    fig = plt.figure(figsize=(10,15), facecolor='w')
    ######################
    
    ax1, ax2 = ImageGrid(fig, 211, nrows_ncols=(1,2), axes_pad=0.5)
    polar_frame(ax1, legend=True)
    polar_imshow(ax1, R, vmin=0, vmax=5)
    logpolar_frame(ax2, legend=True)
    logpolar_imshow(ax2, SC, vmin=0, vmax=5)
    ax1.text(1.1, 1.1, u"a",
          ha="left", va="bottom", fontsize=20, fontweight='bold')
        #ax1.text(0., -1.28, u"0°                                                     90°",
        #      ha="left", va="bottom", fontsize=10)
    ################################
    
    ax1, ax2 = ImageGrid(fig, 212, nrows_ncols=(1,2), axes_pad=0.5)
    polar_frame(ax1, legend=True,reduced=True)
    '''
    zax = zoomed_inset_axes(ax1, 6, loc=1)
    polar_frame(zax, zoom=True)
Beispiel #19
0
def plotImageGrid(images,
                  nrows_ncols=None,
                  extent=None,
                  clim=None,
                  interpolation='none',
                  cmap='gray',
                  imScale=2.,
                  cbar=True,
                  titles=None,
                  titlecol=['r', 'y']):
    """Show a sequence of images on a matplotlib ``ImageGrid``.

    Each entry of ``images`` may be a plain 2D array. It may also be an object
    exposing any of ``computeImage``, ``getImage``, ``getMaskedImage`` and
    ``getArray``; these are unwrapped in that order to reach the pixel array.
    Every array is passed through ``zscale_image`` before display.

    :param images: sequence of arrays or image-like objects.
    :param nrows_ncols: (rows, cols) grid shape. The default is a
        near-square layout with ``floor(sqrt(n))`` rows. Non-positive entries
        are raised to 1.
    :param extent: (x0, x1, y0, y1), passed to ``imshow`` and used to slice
        each array. When None and an object provides a bounding box, that box
        is used for that image only.
    :param clim: (low, high) colour limits. When ``cbar`` is true, pixels are
        also clipped to this range.
    :param interpolation: ``imshow`` interpolation.
    :param cmap: ``imshow`` colour map.
    :param imScale: figure inches per grid cell.
    :param cbar: call ``colorbar`` on each panel's ``cax``. With
        ``cbar_mode='single'`` this is one shared colour bar.
    :param titles: optional per-image titles, same length as ``images``.
    :param titlecol: (text colour, outline colour) for the titles.
    :returns: the ``ImageGrid`` instance.
    """
    import matplotlib.pyplot as plt
    import matplotlib
    matplotlib.style.use('ggplot')
    from mpl_toolkits.axes_grid1 import ImageGrid

    def add_inner_title(ax, title, loc, size=None, **kwargs):
        """Place ``title`` inside ``ax`` as an outlined AnchoredText."""
        from matplotlib.offsetbox import AnchoredText
        from matplotlib.patheffects import withStroke
        if size is None:
            size = dict(size=plt.rcParams['legend.fontsize'],
                        color=titlecol[0])
        at = AnchoredText(title,
                          loc=loc,
                          prop=size,
                          pad=0.,
                          borderpad=0.5,
                          frameon=False,
                          **kwargs)
        ax.add_artist(at)
        at.txt._text.set_path_effects(
            [withStroke(foreground=titlecol[1], linewidth=3)])
        return at

    if nrows_ncols is None:
        # Near-square layout. Builtin int/float are used because
        # np.int/np.float were removed in NumPy 1.24. An empty list is
        # treated as one cell so the division below never divides by zero.
        n_images = max(len(images), 1)
        n_rows = int(np.floor(np.sqrt(n_images)))
        nrows_ncols = (n_rows, int(np.ceil(float(n_images) / n_rows)))
    # Clamp both dimensions to >= 1 by building a new tuple. Item assignment
    # fails on tuples and would mutate a caller-supplied list.
    nrows_ncols = (max(nrows_ncols[0], 1), max(nrows_ncols[1], 1))

    size = (nrows_ncols[1] * imScale, nrows_ncols[0] * imScale)
    fig = plt.figure(1, size)
    igrid = ImageGrid(
        fig,
        111,  # similar to subplot(111)
        nrows_ncols=nrows_ncols,
        direction='row',  # fill the grid row by row
        axes_pad=0.1,  # pad between axes in inch.
        label_mode="L",  # share_all=True,
        cbar_location="right",
        cbar_mode="single",
        cbar_size='7%')
    for i, ii in enumerate(images):
        # Unwrap image-like objects down to a plain pixel array.
        if hasattr(ii, 'computeImage'):
            ii = ii.computeImage()
        if hasattr(ii, 'getImage'):
            ii = ii.getImage()
        if hasattr(ii, 'getMaskedImage'):
            ii = ii.getMaskedImage().getImage()
        # Per-image extent: a bbox-derived extent applies to this image only.
        image_extent = extent
        if hasattr(ii, 'getArray'):
            bbox = ii.getBBox()
            if image_extent is None:
                image_extent = (bbox.getBeginX(), bbox.getEndX(),
                                bbox.getBeginY(), bbox.getEndY())
            ii = ii.getArray()
        if cbar and clim is not None:
            ii = np.clip(ii, clim[0], clim[1])
        if image_extent is not None:
            # NOTE(review): rows are sliced by extent[0:2] (x) and columns by
            # extent[2:4] (y). numpy arrays are indexed [row, col], so confirm
            # this order and any bbox origin offset are intended.
            ii = ii[image_extent[0]:image_extent[1],
                    image_extent[2]:image_extent[3]]
        ii = zscale_image(ii)
        im = igrid[i].imshow(ii,
                             origin='lower',
                             interpolation=interpolation,
                             cmap=cmap,
                             extent=image_extent,
                             clim=clim)
        if cbar:
            igrid[i].cax.colorbar(im)
        if titles is not None:  # assume titles is an array or tuple of same length as images.
            t = add_inner_title(igrid[i], titles[i], loc=2)
            t.patch.set_ec("none")
            t.patch.set_alpha(0.5)
    return igrid
# Predicted category per sample: index of the largest value along axis 1 of
# ``cod2`` (presumably one row of scores per sample -- confirm).
y_pred = cod2.argmax(1)
#y_pred = kmeans.predict(cod2)

print(np.unique(y_pred))
# Indices of the samples assigned to category ``cat1``, in random order.
cat1 = 4
ind, = np.where(y_pred == cat1)
np.random.shuffle(ind)

# Collect the first 100 shuffled images of that category.
# NOTE(review): raises IndexError if fewer than 100 samples are in ``cat1``.
ims = 100 * [None]
for j in range(100):
    ims[j] = (images[ind[j]])

plt.ion()

# 8x8 grid with no padding between panels: only 64 axes, so at most 64 of
# the 100 collected images can be placed on it.
fig = plt.figure(figsize=(10, 10))
grid = ImageGrid(fig, 111, nrows_ncols=(8, 8), axes_pad=0.0, label_mode=None)
i = 0
for ax, im in zip(grid, ims):
    #ax.tick_params(labelbottom=False,labelleft=False)
    ax.imshow(im[:, :])
    ax.set_xticks([-1])
    ax.set_yticks([-1])
    rounded = [round(num1, 2) for num1 in cod2[ind[i]]]
    ax.text(0.05,
            0.9,
            str(rounded[0:3]),
            transform=ax.transAxes,
            fontsize=5,
            color=[1, 1, 1])
    ax.text(0.1,
            0.8,
def event_display(config):
    """Render event-display and PMT-geometry plots for a list of WCSim events.

    ``config.input_file`` is indexable; its first element names a text file
    with one whitespace-separated request per line::

        <softmax> <wcsim_root_file> <event_index>

    For each request, PDFs are written into
    ``<output_dir><class>_softmax<whole>_<fraction>/``, where ``<class>`` is
    ``get_class(wcsim_root_file)`` and ``<whole>``/``<fraction>`` are the parts
    of ``<softmax>`` around its first '.'.  Blank request lines are skipped.
    A request whose chosen trigger has no digi-hits produces only the geometry
    plots and processing continues with the next request.

    Side effects:
        * replaces ``config.input_file`` with its first element and appends a
          trailing '/' to ``config.output_dir`` if it is missing;
        * creates the output directory and one sub-directory per request;
        * calls ``sys.exit()`` if a ROOT file name contains none of the
          particle tags "_gamma", "_e", "_mu", "_pi0".

    Args:
        config: object with ``input_file`` and ``output_dir`` (str) attributes.
    """
    config.input_file = config.input_file[0]
    config.output_dir += ('' if config.output_dir.endswith('/') else '/')
    if not os.path.isdir(config.output_dir):
        os.mkdir(config.output_dir)
    print("Reading request from: " + str(config.input_file))
    print("output directory: " + str(config.output_dir))

    # Colour maps/norms do not depend on the request, so build them once.
    palette = _ev_disp_palette()

    with open(config.input_file, 'r') as request_list:
        for line in request_list:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines in the request list
            softmax, input_file, ev = fields[0], fields[1], int(fields[2])

            print("now processing " + input_file + " at index " + str(ev))
            _ev_disp_request(config.output_dir, softmax, input_file, ev,
                             palette)


def _ev_disp_palette():
    """Build the colour maps and categorical norms shared by all requests.

    Returns a dict with:
        'cm': the continuous plasma colormap used for charge plots;
        'cm_cat': a LinearSegmentedColormap copy of plasma for categories;
        'norm_pmt_in_module', 'norm_module_row', 'norm_module_col':
            BoundaryNorms with unit-wide bins over [0, 19], [0, 16], [0, 40].
    """
    cm = matplotlib.cm.plasma
    cm_cat = lsc.from_list('Custom cmap', [cm(i) for i in range(cm.N)], cm.N)
    return {
        'cm': cm,
        'cm_cat': cm_cat,
        'norm_pmt_in_module': matplotlib.colors.BoundaryNorm(
            np.linspace(0, 19, 20), cm_cat.N),
        'norm_module_row': matplotlib.colors.BoundaryNorm(
            np.linspace(0, 16, 17), cm_cat.N),
        'norm_module_col': matplotlib.colors.BoundaryNorm(
            np.linspace(0, 40, 41), cm_cat.N),
    }


def _ev_disp_request(output_dir, softmax, input_file, ev, palette):
    """Process one request line: geometry plots plus the event display of entry ``ev``.

    The ROOT file is always closed, even if plotting fails.
    """
    event_class = get_class(input_file)
    softmax_whole, _, softmax_frac = softmax.partition('.')
    write_dir = (output_dir + event_class + "_softmax" + softmax_whole + '_' +
                 softmax_frac.split('.')[0] + "/")
    if not os.path.isdir(write_dir):
        os.mkdir(write_dir)

    if not any(tag in input_file for tag in ("_gamma", "_e", "_mu", "_pi0")):
        print("Unknown input file particle type")
        sys.exit()

    root_file = ROOT.TFile(input_file, "read")
    try:
        tree = root_file.Get("wcsimT")
        print("number of entries in the tree: " + str(tree.GetEntries()))

        geotree = root_file.Get("wcsimGeoT")
        print("number of entries in the geometry tree: " +
              str(geotree.GetEntries()))
        geotree.GetEntry(0)
        geo = geotree.wcsimrootgeom

        r_max = _ev_disp_geometry(geo, write_dir, palette)
        _ev_disp_event(tree, geo, ev, r_max, write_dir, palette)
    finally:
        root_file.Close()


def _ev_disp_geometry(geo, write_dir, palette):
    """Plot the PMT layout (figures 101-107) and return the largest PMT radius.

    PMTs are visited in a random order (np.random.shuffle), presumably so that
    overlapping markers are not drawn in a systematic order.  WCSim position
    components (0, 1, 2) are used as plot (y, z, x).

    Returns:
        r_max: largest hypot(x, y) over all PMTs; used to unroll the barrel
            into an arc-length coordinate ``r_max * phi``.
    """
    num_pmts = geo.GetWCNumPMT()
    pmt_ids = np.arange(num_pmts)
    np.random.shuffle(pmt_ids)
    module_ids = module_index(pmt_ids)
    # NOTE(review): slot computed as 0-based id % 19 here, while hits use the
    # pmt_in_module_id() helper — confirm both agree.
    in_module_ids = (pmt_ids % 19).astype(np.float64)

    pos_x = np.zeros(num_pmts)
    pos_y = np.zeros(num_pmts)
    pos_z = np.zeros(num_pmts)
    for k, pmt_id in enumerate(pmt_ids):
        pmt = geo.GetPMT(pmt_id)
        pos_x[k] = pmt.GetPosition(2)
        pos_y[k] = pmt.GetPosition(0)
        pos_z[k] = pmt.GetPosition(1)

    r_max = np.amax(np.hypot(pos_x, pos_y))
    arc = r_max * np.arctan2(pos_y, pos_x)

    wall = np.where(is_barrel(module_ids))
    top = np.where(is_top(module_ids))
    bottom = np.where(is_bottom(module_ids))
    wall_row, wall_col = row_col(module_ids[wall])

    cmap = palette['cm_cat']
    slot_norm = palette['norm_pmt_in_module']
    _ev_disp_scatter(101, write_dir + "pos_arc_z_disp_all_tubes.pdf", arc,
                     pos_z, in_module_ids, 'arc along the wall', 'z',
                     "pmt in module", 5, cmap, slot_norm, range(20))
    _ev_disp_scatter(102, write_dir + "pos_x_y_disp_all_tubes.pdf", pos_x,
                     pos_y, in_module_ids, 'x', 'y', "pmt in module", 5, cmap,
                     slot_norm, range(20))
    _ev_disp_scatter(103, write_dir + "pos_arc_z_disp_wall_tubes.pdf",
                     arc[wall], pos_z[wall], in_module_ids[wall],
                     'arc along the wall', 'z', "pmt in module", 5, cmap,
                     slot_norm, range(20))
    _ev_disp_scatter(104,
                     write_dir + "pos_arc_z_disp_wall_tubes_color_row.pdf",
                     arc[wall], pos_z[wall], wall_row, 'arc along the wall',
                     'z', "wall module row", 5, cmap,
                     palette['norm_module_row'], range(16))
    _ev_disp_scatter(105,
                     write_dir + "pos_arc_z_disp_wall_tubes_color_col.pdf",
                     arc[wall], pos_z[wall], wall_col, 'arc along the wall',
                     'z', "wall module column", 5, cmap,
                     palette['norm_module_col'], range(40))
    _ev_disp_scatter(106, write_dir + "pos_x_y_disp_top_tubes.pdf",
                     pos_x[top], pos_y[top], in_module_ids[top], 'x', 'y',
                     "pmt in module", 5, cmap, slot_norm, range(20))
    _ev_disp_scatter(107, write_dir + "pos_x_y_disp_bottom_tubes.pdf",
                     pos_x[bottom], pos_y[bottom], in_module_ids[bottom], 'x',
                     'y', "pmt in module", 5, cmap, slot_norm, range(20))
    return r_max


def _ev_disp_event(tree, geo, ev, r_max, write_dir, palette):
    """Plot the digi-hits of the biggest trigger of tree entry ``ev`` (figures 1-11).

    The trigger with the most Cherenkov digi-hits is used (earliest on ties,
    trigger 0 if all are empty); its index appears in every output file name
    as ``trig_<index>``.  Nothing is plotted if that trigger has no hits.
    """
    tree.GetEvent(ev)
    superevent = tree.wcsimrootevent
    n_triggers = superevent.GetNumberOfEvents()
    print("number of sub events: " + str(n_triggers))

    index = 0
    most_digihits = 0
    for trig in range(n_triggers):
        n_digi = superevent.GetTrigger(trig).GetNcherenkovdigihits()
        if n_digi > most_digihits:
            most_digihits = n_digi
            index = trig
    trigger = superevent.GetTrigger(index)

    header = trigger.GetHeader()
    print("event date and number: " + str(header.GetDate()) + " " +
          str(header.GetEvtNum()))

    n_hits = trigger.GetNcherenkovdigihits()
    print("Ncherenkovdigihits " + str(n_hits))
    if n_hits == 0:
        print("event, trigger has no hits " + str(ev) + " " + str(index))
        return

    # PMT numbering (WCSim tube ids start at 1, no gaps):
    #   * each 19 consecutive PMTs form one mPMT module, so (id-1)//19 is the
    #     module number;
    #   * id % 19 is the slot in the module: 1-12 outer ring, 13-18 inner
    #     ring, 0 the centre PMT;
    #   * modules run round the barrel from the second-highest ring down to
    #     the lowest, then the bottom end-cap row by row, then the skipped
    #     highest barrel ring, then the top end-cap row by row.  End-cap row
    #     sizes were observed as 2, 6, 8, 10, 10, 12 and back down.
    pos_x = np.zeros(n_hits)
    pos_y = np.zeros(n_hits)
    pos_z = np.zeros(n_hits)
    dir_u = np.zeros(n_hits)
    dir_v = np.zeros(n_hits)
    dir_w = np.zeros(n_hits)
    q = np.zeros(n_hits)
    t = np.zeros(n_hits)
    pmt_index = np.zeros(n_hits, dtype=np.int32)

    digihits = trigger.GetCherenkovDigiHits()
    for k in range(n_hits):
        hit = digihits.At(k)
        tube = hit.GetTubeId() - 1  # geometry PMT indices are 0-based
        pmt_index[k] = tube
        pmt = geo.GetPMT(tube)
        # WCSim components (0, 1, 2) are plotted as (y, z, x).
        pos_x[k] = pmt.GetPosition(2)
        pos_y[k] = pmt.GetPosition(0)
        pos_z[k] = pmt.GetPosition(1)
        dir_u[k] = pmt.GetOrientation(2)
        dir_v[k] = pmt.GetOrientation(0)
        dir_w[k] = pmt.GetOrientation(1)
        q[k] = hit.GetQ()
        t[k] = hit.GetT()

    module_ids = module_index(pmt_index)
    in_module_ids = pmt_in_module_id(pmt_index)
    wall = np.where(is_barrel(module_ids))
    top = np.where(is_top(module_ids))
    bottom = np.where(is_bottom(module_ids))

    # Unrolled barrel coordinate: arc length at the largest PMT radius.
    arc_wall = (r_max * np.arctan2(pos_y, pos_x))[wall]

    # Barrel hits as a (16, 40, 38) cube indexed by module row, module column
    # and slot: slots 0-18 hold charge, slots 19-37 hold time.
    wall_row, wall_col = row_col(module_ids[wall])
    wall_slot = in_module_ids[wall]
    wall_data = np.zeros((16, 40, 38))
    wall_data[wall_row, wall_col, wall_slot] = q[wall]
    wall_data[wall_row, wall_col, wall_slot + 19] = t[wall]
    wall_q_max = np.amax(wall_data[:, :, 0:19], axis=-1)
    wall_q_sum = np.sum(wall_data[:, :, 0:19], axis=-1)

    # Arrow length proportional to charge, the largest charge giving 500.
    scaled_q = 500 * q / np.amax(q)

    cm = palette['cm']
    suffix = "_ev_{}_trig_{}.pdf".format(ev, index)

    fig1 = plt.figure(num=1, clear=True)
    fig1.set_size_inches(10, 8)
    ax1 = fig1.add_subplot(111, projection='3d', azim=35, elev=20)
    hits_3d = ax1.scatter(pos_x, pos_y, pos_z, c=q, s=2, alpha=0.4, cmap=cm,
                          marker='.')
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    ax1.set_zlabel('z')
    cb1 = fig1.colorbar(hits_3d, pad=0.03)
    cb1.set_label("charge")
    fig1.savefig(write_dir + "ev_disp" + suffix)

    # Arrows along each hit PMT's orientation, coloured by hit time.  A fresh
    # Normalize autoscales to this event's times on its first call; the
    # colorbar uses the same colormap and norm as the arrows.
    time_cmap = plt.cm.spring
    time_norm = plt.Normalize()
    arrow_colors = time_cmap(time_norm(t))
    fig2 = plt.figure(num=2, clear=True)
    fig2.set_size_inches(10, 8)
    ax2 = fig2.add_subplot(111, projection='3d', azim=35, elev=20)
    ax2.quiver(pos_x, pos_y, pos_z,
               dir_u * scaled_q, dir_v * scaled_q, dir_w * scaled_q,
               colors=arrow_colors, alpha=0.4, cmap=time_cmap)
    ax2.set_xlabel('x')
    ax2.set_ylabel('y')
    ax2.set_zlabel('z')
    time_mappable = matplotlib.cm.ScalarMappable(cmap=time_cmap, norm=time_norm)
    time_mappable.set_array([])
    # ax= is required by newer matplotlib for a mappable not attached to Axes.
    cb2 = fig2.colorbar(time_mappable, ax=ax2, pad=0.03)
    cb2.set_label("time")
    fig2.savefig(write_dir + "ev_disp_quiver" + suffix)

    _ev_disp_scatter(3, write_dir + "ev_disp_wall" + suffix, arc_wall,
                     pos_z[wall], q[wall], 'arc along the wall', 'z', "charge",
                     2, cm)
    _ev_disp_scatter(4, write_dir + "ev_disp_top" + suffix, pos_x[top],
                     pos_y[top], q[top], 'x', 'y', "charge", 2, cm)
    _ev_disp_scatter(5, write_dir + "ev_disp_bottom" + suffix, pos_x[bottom],
                     pos_y[bottom], q[bottom], 'x', 'y', "charge", 2, cm)

    _ev_disp_module_map(6, write_dir + "q_sum_disp" + suffix, wall_q_sum, cm,
                        "total charge in module")
    _ev_disp_module_map(7, write_dir + "q_max_disp" + suffix, wall_q_max, cm,
                        "maximum charge in module")

    _ev_disp_hist(8, write_dir + "q_pmt_disp" + suffix, q, 'charge')
    _ev_disp_hist(9, write_dir + "t_pmt_disp" + suffix, t, 'time')

    _ev_disp_module_grid(10, write_dir + "q_disp_grid" + suffix,
                         wall_data[:, :, 0:19], cm, summary=wall_q_max)
    _ev_disp_module_grid(11, write_dir + "t_disp_grid" + suffix,
                         wall_data[:, :, 19:38], cm)


def _ev_disp_scatter(fig_num, path, x, y, c, xlabel, ylabel, cbar_label,
                     size, cmap, norm=None, ticks=None):
    """Save a 2-D scatter of (x, y) coloured by ``c`` with a labelled colorbar.

    Reuses matplotlib figure ``fig_num`` (cleared first, 10x8 inches).
    ``norm``/``ticks`` left as None give matplotlib's defaults.
    """
    fig = plt.figure(num=fig_num, clear=True)
    fig.set_size_inches(10, 8)
    ax = fig.add_subplot(111)
    points = ax.scatter(x, y, c=c, s=size, cmap=cmap, norm=norm, marker='.')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    colorbar = fig.colorbar(points, ticks=ticks, pad=0.1)
    colorbar.set_label(cbar_label)
    fig.savefig(path)


def _ev_disp_module_map(fig_num, path, module_map, cmap, cbar_label):
    """Save a 2-D per-module map as an image with row 0 drawn at the bottom."""
    fig = plt.figure(num=fig_num, clear=True)
    fig.set_size_inches(10, 4)
    ax = fig.add_subplot(111)
    image = ax.imshow(np.flip(module_map, axis=0), cmap=cmap)
    ax.set_xlabel('arc index')
    ax.set_ylabel('z index')
    colorbar = fig.colorbar(image, pad=0.1)
    colorbar.set_label(cbar_label)
    fig.savefig(path)


def _ev_disp_hist(fig_num, path, values, xlabel):
    """Save a 50-bin density histogram of ``values``."""
    fig = plt.figure(num=fig_num, clear=True)
    fig.set_size_inches(10, 8)
    ax = fig.add_subplot(111)
    ax.hist(values, 50, density=True, facecolor='blue', alpha=0.75)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("PMT's above threshold")
    fig.savefig(path)


def _ev_disp_module_grid(fig_num, path, layers, cmap, summary=None):
    """Save one panel per slot ``layers[:, :, k]`` on a 4x5 ImageGrid.

    If ``summary`` is given it is drawn in panel 19 and the grid's single
    shared colorbar is attached to it.
    """
    fig = plt.figure(num=fig_num, clear=True)
    fig.set_size_inches(15, 5)
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(4, 5),
        axes_pad=0.0,
        share_all=True,
        label_mode="L",
        cbar_location="top",
        cbar_mode="single",
    )
    for slot in range(layers.shape[-1]):
        grid[slot].imshow(np.flip(layers[:, :, slot], axis=0), cmap=cmap)
    if summary is not None:
        summary_image = grid[19].imshow(np.flip(summary, axis=0), cmap=cmap)
        grid.cbar_axes[0].colorbar(summary_image)
    fig.savefig(path)
def plot_av_vs_nhi_grid(nhi_images,
                        av_images,
                        nhi_error_images=None,
                        av_error_images=None,
                        limits=None,
                        savedir='./',
                        filename=None,
                        show=False,
                        scale=('linear', 'linear'),
                        returnimage=False,
                        hess_binsize=None,
                        title='',
                        plot_type='hexbin',
                        color_scale='linear'):
    """Plot A_V against N(HI) for several image pairs on one ImageGrid figure.

    Panel ``i`` shows ``av_images[i]`` versus ``nhi_images[i]`` as points with
    error bars.  Only pixels where both values are positive and not NaN are kept.
    Where an error is given as an array, its pixel must not be NaN either.

    Args:
        nhi_images, av_images: sequences of equally-shaped arrays, one pair
            per panel.
        nhi_error_images, av_error_images: optional sequences parallel to the
            images.  Each element is an array shaped like its image or a
            scalar applied to every pixel.  ``None`` (for the whole sequence or
            an element) draws no error bar on that axis.
        limits: optional (xmin, xmax, ymin, ymax) applied to every panel.
        savedir, filename: when ``filename`` is given the figure is saved to
            ``savedir + filename``.
        show: call ``fig.show()`` at the end.
        scale: (x_scale, y_scale) names for ``set_xscale``/``set_yscale``.
        title: used as every panel title and as the figure suptitle.
        returnimage, hess_binsize, plot_type, color_scale: accepted for
            interface compatibility; currently unused.

    Raises:
        ValueError: if ``av_images`` is empty.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import ImageGrid

    def _errors_at(error, indices, shape):
        """Per-point errors for the kept pixels, or None for no error bars.

        Arrays (including MaskedArray) are indexed like the image; a scalar
        error is broadcast to ``shape``.
        """
        if error is None:
            return None
        if isinstance(error, np.ndarray):
            return error[indices]
        return error * np.ones(shape)

    n_panels = len(av_images)
    if n_panels == 0:
        raise ValueError('av_images is empty; nothing to plot')

    # Smallest n x n grid that fits every panel, minus its last row whenever
    # (n - 1) rows of n columns already hold all panels.
    n = int(np.ceil(n_panels**0.5))
    if n * (n - 1) >= n_panels:
        nrows, ncols = n - 1, n
        y_scaling = 1.0 - 1.0 / n
    else:
        nrows, ncols = n, n
        y_scaling = 1.0

    # Set up plot aesthetics (global rcParams side effect)
    plt.clf()
    plt.rcdefaults()
    font_scale = 12
    params = {
        'axes.labelsize': font_scale,
        'axes.titlesize': font_scale,
        # General text size; the old 'text.fontsize' key is rejected by
        # modern matplotlib's rcParams.
        'font.size': font_scale,
        'legend.fontsize': font_scale * 3 / 4.0,
        'xtick.labelsize': font_scale,
        'ytick.labelsize': font_scale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': True,
        'figure.figsize': (8, 8 * y_scaling),
    }
    plt.rcParams.update(params)

    fig = plt.figure()
    imagegrid = ImageGrid(fig, (1, 1, 1),
                          nrows_ncols=(nrows, ncols),
                          ngrids=n_panels,
                          axes_pad=0.25,
                          aspect=False,
                          label_mode='L',
                          share_all=True)

    for i in range(n_panels):
        av = av_images[i]
        nhi = nhi_images[i]
        av_error = None if av_error_images is None else av_error_images[i]
        nhi_error = None if nhi_error_images is None else nhi_error_images[i]

        # x == x is False only for NaN.  Array-valued errors must also be
        # non-NaN at a pixel for it to be kept.
        keep = (av == av) & (nhi == nhi) & (nhi > 0) & (av > 0)
        if isinstance(av_error, np.ndarray):
            keep = keep & (av_error == av_error)
        if isinstance(nhi_error, np.ndarray):
            keep = keep & (nhi_error == nhi_error)
        indices = np.where(keep)

        av_nonans = av[indices]
        nhi_nonans = nhi[indices]
        av_error_nonans = _errors_at(av_error, indices, av_nonans.shape)
        if av_error_nonans is not None:
            # A_V errors are passed as a plain ndarray (any mask is dropped).
            av_error_nonans = np.asarray(av_error_nonans)
        nhi_error_nonans = _errors_at(nhi_error, indices, nhi_nonans.shape)

        ax = imagegrid[i]
        ax.errorbar(nhi_nonans.ravel(),
                    av_nonans.ravel(),
                    xerr=(None if nhi_error_nonans is None
                          else nhi_error_nonans.ravel()),
                    yerr=(None if av_error_nonans is None
                          else av_error_nonans.ravel()),
                    alpha=0.3,
                    color='k',
                    marker='^',
                    ecolor='k',
                    linestyle='none',
                    markersize=4)

        ax.set_xscale(scale[0], nonposx='clip')
        ax.set_yscale(scale[1], nonposy='clip')

        if limits is not None:
            ax.set_xlim(limits[0], limits[1])
            ax.set_ylim(limits[2], limits[3])

        ax.set_xlabel(r'$N(HI)$ (10$^{20}$ cm$^{-2}$)')
        ax.set_ylabel(r'A$_{\rm V}$ (mag)')
        ax.set_title(title)
        ax.grid(True)

    if title is not None:
        fig.suptitle(title, fontsize=font_scale * 1.5)
    if filename is not None:
        plt.savefig(savedir + filename)
    if show:
        fig.show()
def attention_epoch_plot(net,
                         folder_name,
                         source_images,
                         logged=False,
                         width=5,
                         device=torch.device('cpu'),
                         layer_name_base='attention',
                         layer_no=2,
                         cmap_name='magma',
                         figsize=(100, 100)):
    """
    Plot a grid of attention maps showing how they develop over training.

    Each grid row belongs to one source: the source image (with a contour
    overlay) followed by ``width`` attention maps sampled evenly across the
    saved epochs.

    Args:
        net: network forwarded to ``AttentionImagesByEpoch``.
        folder_name: folder forwarded to ``AttentionImagesByEpoch``
            (presumably where the per-epoch models are stored - TODO confirm).
        source_images: array of source images indexed along axis 0; its
            length is the number of grid rows.
        logged: if True, source images are exponentiated before plotting and
            every panel (sources and maps) is drawn as ``np.log`` of its data.
        width: number of attention maps per row; clipped to the number of
            saved epochs.
        device, layer_name_base, layer_no: forwarded to
            ``AttentionImagesByEpoch``.
        cmap_name: matplotlib colormap name. ``'RGB'`` requests
            non-averaged maps (``mean=False``) drawn with ``'magma'``.
        figsize: matplotlib figure size.
    Out:
        None. The figure is shown with ``plt.show()`` and the epoch labels of
        the sampled columns are printed.
    """

    # 'RGB' keeps the non-averaged maps; every other colormap uses mean=True.
    # mean_ used to be bound only in the 'RGB' branch, so any other
    # cmap_name (including the default) raised NameError below.
    mean_ = True
    if cmap_name == 'RGB':
        mean_ = False
        cmap_name = 'magma'

    # Generate attention maps for each available Epoch
    attention_maps_temp, _og_attention_maps, epoch_labels = AttentionImagesByEpoch(
        source_images,
        folder_name,
        net,
        epoch=2000,
        device=device,
        layer_name_base=layer_name_base,
        layer_no=layer_no,
        mean=mean_)

    # Maps are stored epoch-major: index sample_number * epoch + source.
    sample_number = source_images.shape[0]
    no_saved_attentions_epochs = len(attention_maps_temp) // sample_number

    if width <= no_saved_attentions_epochs:
        width_array = np.linspace(0,
                                  no_saved_attentions_epochs - 1,
                                  num=width,
                                  dtype=np.int32)
    else:
        width = no_saved_attentions_epochs
        width_array = range(no_saved_attentions_epochs)

    # One label per sampled epoch column (identical for every source row).
    labels = [epoch_labels[sample_number * i] for i in width_array]

    # Row-major image order: source j, then its maps at each sampled epoch.
    imgs = []
    for j in range(sample_number):
        source = source_images[j].squeeze()
        imgs.append(np.exp(source) if logged else source)
        imgs.extend(attention_maps_temp[sample_number * i + j]
                    for i in width_array)

    # Define the plot of the grid of images
    fig = plt.figure(figsize=figsize)
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(sample_number, width + 1),
        axes_pad=0.02,  # pad between axes in inch.
    )
    for idx, (ax, im) in enumerate(zip(grid, imgs)):
        # Channel-first RGB images are moved to channel-last for imshow.
        if im.shape[0] == 3:
            im = im.transpose(1, 2, 0)
        if logged:
            ax.imshow(np.log(im), cmap=cmap_name)
        else:
            ax.imshow(im, cmap=cmap_name)
        # The first column of every row is the source image: overlay a contour.
        if idx % (width + 1) == 0:
            ax.contour(im, 1, cmap='cool', alpha=0.5)
        ax.axis('off')
    print(
        f'Source images followed by their respective averaged attention maps at epochs:\n{labels}'
    )
    plt.show()
def plot_av_vs_nhi(nhi_image,
                   av_image,
                   limits=None,
                   savedir='./',
                   filename=None,
                   show=False,
                   scale=('linear', 'linear'),
                   returnimage=False,
                   hess_binsize=None,
                   title='',
                   plot_type='hexbin',
                   color_scale='linear'):
    """Plot A_V against N(HI) for every valid pixel of two matching images.

    Pixels where either image is NaN or non-positive are discarded first.

    Parameters
    ----------
    nhi_image, av_image : ndarray
        N(HI) and A_V images; they must be index-compatible (same shape).
    limits : sequence of 4 floats, optional
        Axis limits as ``(xmin, xmax, ymin, ymax)``.
    savedir : str
        Prefix joined to ``filename`` when saving.
    filename : str, optional
        If given, the figure is saved to ``savedir + filename`` at 600 dpi.
    show : bool
        If True, call ``fig.show()``.
    scale : pair of str
        x- and y-axis scales, e.g. ``'linear'`` or ``'log'``.
    returnimage : bool
        If True, return the plotted artist (hexbin or scatter collection),
        or None if nothing was drawn.
    hess_binsize : unused
        Accepted for backward compatibility only.
    title : str
        Axes title.
    plot_type : {'hexbin', 'scatter'}
        Hexbin density map with a colour bar, or a plain scatter plot.
    color_scale : {'linear', 'log'}
        Normalisation of the hexbin counts; ignored for scatter plots.
    """

    # Import external modules
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from matplotlib.colors import LogNorm
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Keep pixels that are non-NaN in both images (NaN != NaN) and positive.
    valid = ((nhi_image == nhi_image) & (av_image == av_image) &
             (nhi_image > 0) & (av_image > 0))
    nhi_values = nhi_image[valid].ravel()
    av_values = av_image[valid].ravel()

    # Set up plot aesthetics
    plt.clf()
    plt.rcdefaults()
    font_scale = 10
    plt.rcParams.update({
        'axes.labelsize': font_scale,
        'axes.titlesize': font_scale,
        # 'text.fontsize' is not a valid rcParam in current matplotlib;
        # 'font.size' is the supported key.
        'font.size': font_scale,
        'legend.fontsize': font_scale * 3 / 4,
        'xtick.labelsize': font_scale,
        'ytick.labelsize': font_scale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': (4, 4),
    })

    # Scatter plots get no colour bar; ImageGrid expects None, not 'None'.
    cbar_mode = None if plot_type == 'scatter' else 'single'

    # Create figure
    fig = plt.figure()
    imagegrid = ImageGrid(fig, (1, 1, 1),
                          nrows_ncols=(1, 1),
                          ngrids=1,
                          axes_pad=0.25,
                          aspect=False,
                          label_mode='L',
                          share_all=True,
                          cbar_mode=cbar_mode,
                          cbar_pad=0.1,
                          cbar_size=0.2)

    ax = imagegrid[0]

    image = None
    # Strings are compared with ==; `is` tests identity and is unreliable.
    if plot_type == 'hexbin':
        if color_scale in ('linear', 'log'):
            hexbin_kwargs = dict(mincnt=1,
                                 xscale=scale[0],
                                 yscale=scale[1],
                                 cmap=cm.gist_stern)
            if color_scale == 'log':
                hexbin_kwargs.update(norm=LogNorm(), gridsize=(100, 200))
            image = ax.hexbin(nhi_values, av_values, **hexbin_kwargs)
            cb = ax.cax.colorbar(image)
            # Older axes_grid1 colour bars expose set_label_text; newer ones
            # are standard Colorbar objects with set_label.
            set_cbar_label = getattr(cb, 'set_label_text', None) or cb.set_label
            set_cbar_label('Bin Counts')
    elif plot_type == 'scatter':
        image = ax.scatter(nhi_values,
                           av_values,
                           alpha=0.3,
                           color='k')
        ax.set_xscale(scale[0])
        ax.set_yscale(scale[1])

    if limits is not None:
        ax.set_xlim(limits[0], limits[1])
        ax.set_ylim(limits[2], limits[3])

    # Adjust asthetics
    ax.set_xlabel(r'$N(HI)$ (10$^{20}$ cm$^{-2}$)')
    ax.set_ylabel(r'A$_{\rm V}$ (mag)')
    ax.set_title(title)
    ax.grid(True)

    if filename is not None:
        plt.savefig(savedir + filename, bbox_inches='tight', dpi=600)
    if show:
        fig.show()
    if returnimage:
        # Previously returned the undefined name `correlations_image`.
        return image
Beispiel #25
0
def getFigAx(subplot, name=None, title=None, figsize=None,  mpl=None, margins=None,
             sharex=None, sharey=None, AxesGrid=False, ngrids=None, direction='row',
             axes_pad = None, add_all=True, share_all=None, aspect=False,
             label_mode='L', cbar_mode=None, cbar_location='right',
             cbar_pad=None, cbar_size='5%', axes_class=None, lreduce=True):
  ''' Create a MyFigure and a panel of axes (deprecated).

      subplot: number of panels (1, 2, 3, 4, 6 or 9) or a (rows, columns) pair
               of integers; anything else raises TypeError (or
               NotImplementedError for an unsupported panel count).
      AxesGrid: if True, lay the axes out with mpl_toolkits ImageGrid (the
               ngrids/direction/axes_pad/cbar_* arguments are passed to it);
               otherwise use pyplot.subplots with sharex/sharey.
      figsize, margins: chosen from the panel layout when None; margins are
               (left, bottom, width, height) in figure fractions.
      Returns (fig, axes): axes is a flat tuple for AxesGrid, else what
      pyplot.subplots(squeeze=lreduce) returns; with lreduce a single axes
      is returned bare.
  '''
  # configure matplotlib
  warn('Deprecated function: use Figure or Axes class methods.')
  if mpl is None: import matplotlib as mpl
  elif isinstance(mpl,dict): mpl = loadMPL(**mpl) # there can be a mplrc, but also others
  elif not isinstance(mpl,ModuleType): raise TypeError
  from plotting.figure import MyFigure # prevent circular reference
  # figure out subplots
  if isinstance(subplot,(np.integer,int)):
    if subplot == 1: subplot = (1,1)
    elif subplot == 2: subplot = (1,2)
    elif subplot == 3: subplot = (1,3)
    elif subplot == 4: subplot = (2,2)
    elif subplot == 6: subplot = (2,3)
    elif subplot == 9: subplot = (3,3)
    else: raise NotImplementedError
  elif (isinstance(subplot,(tuple,list)) and len(subplot) == 2
        and all(isinstance(n,(np.integer,int)) for n in subplot)):
    # normalize lists to tuples: the layout comparisons below use tuples
    subplot = tuple(subplot)
  else:
    # previously the misplaced parentheses let almost any invalid value through
    raise TypeError
  # create figure
  if figsize is None:
    if subplot == (1,1): figsize = (3.75,3.75)
    elif subplot == (1,2) or subplot == (1,3): figsize = (6.25,3.75)
    elif subplot == (2,1) or subplot == (3,1): figsize = (3.75,6.25)
    else: figsize = (6.25,6.25)
  # figure out margins
  if margins is None:
    # N.B.: the rectangle definition is presumably left, bottom, width, height
    if subplot == (1,1): margins = (0.09,0.09,0.88,0.88)
    elif subplot == (1,2) or subplot == (1,3): margins = (0.06,0.1,0.92,0.87)
    elif subplot == (2,1) or subplot == (3,1): margins = (0.09,0.11,0.88,0.82)
    elif subplot == (2,2) or subplot == (3,3): margins = (0.055,0.055,0.925,0.925)
    else: margins = (0.09,0.11,0.88,0.82)
    if title is not None: margins = margins[:3]+(margins[3]-0.03,) # make room for title
  if AxesGrid:
    if share_all is None: share_all = True
    if axes_pad is None: axes_pad = 0.05
    # create axes using the Axes Grid package
    fig = mpl.pylab.figure(facecolor='white', figsize=figsize, FigureClass=MyFigure)
    if axes_class is None:
      from plotting.axes import MyLocatableAxes
      axes_class=(MyLocatableAxes,{})
    from mpl_toolkits.axes_grid1 import ImageGrid
    # AxesGrid: http://matplotlib.org/mpl_toolkits/axes_grid/users/overview.html
    grid = ImageGrid(fig, margins, nrows_ncols = subplot, ngrids=ngrids, direction=direction,
                     axes_pad=axes_pad, add_all=add_all, share_all=share_all, aspect=aspect,
                     label_mode=label_mode, cbar_mode=cbar_mode, cbar_location=cbar_location,
                     cbar_pad=cbar_pad, cbar_size=cbar_size, axes_class=axes_class)
    # return figure and axes
    axes = tuple([ax for ax in grid]) # this is already flattened
    if lreduce and len(axes) == 1: axes = axes[0] # return a bare axes instance, if there is only one axes
  else:
    # create axes using normal subplot routine
    if axes_pad is None: axes_pad = 0.03
    wspace = hspace = axes_pad
    if share_all:
      sharex='all'; sharey='all'
    if sharex is True or sharex is None: sharex = 'col' # default
    if sharey is True or sharey is None: sharey = 'row'
    # shared axes drop tick labels, so slightly less padding is needed
    if sharex: hspace -= 0.015
    if sharey: wspace -= 0.015
    # create figure
    from matplotlib.pyplot import subplots
    # GridSpec: http://matplotlib.org/users/gridspec.html
    fig, axes = subplots(subplot[0], subplot[1], sharex=sharex, sharey=sharey,
                         squeeze=lreduce, facecolor='white', figsize=figsize, FigureClass=MyFigure)
    # there is also a subplot_kw=dict() and fig_kw=dict()
    # just adjust margins: convert (left, bottom, width, height) to edges
    margin_dict = dict(left=margins[0], bottom=margins[1], right=margins[0]+margins[2],
                       top=margins[1]+margins[3], wspace=wspace, hspace=hspace)
    fig.subplots_adjust(**margin_dict)
  # add figure title
  if name is not None: fig.canvas.set_window_title(name) # window title
  if title is not None: fig.suptitle(title) # title on figure (printable)
  # return Figure and tuple of axes
  return fig, axes
Beispiel #26
0
def tiled_axis(ts, filename=None):
    """Write one PNG per day of ``ts`` using a three-row ImageGrid figure.

    For each day returned by ``glucose.get_days(ts.time)`` the matching grid
    axes is limited to ``[day, day + delta]``, its date ticks are formatted
    and the whole figure is printed to ``'<day.isoformat()>-<index>.png'``.
    The grid has three axes, so at most three days are drawn.

    Args:
        ts: series with a ``time`` attribute (presumably a glucose time
            series - TODO confirm type).
        filename: unused; kept for backward compatibility.

    Returns:
        The FigureCanvas the PNG files were printed from.
    """
    fig = Figure((2.56 * 4, 2.56 * 4), 300)
    canvas = FigureCanvas(fig)

    grid = ImageGrid(
        fig,
        111,  # similar to subplot(111)
        nrows_ncols=(3, 1),
        axes_pad=0.5,
        add_all=True,
        label_mode="L",
    )
    # Width of each day's x-window: 2 days 12 hours from the day's start.
    delta = dates.relativedelta(days=2, hours=12)
    # XXX: gets a list of days.
    timestamps = glucose.get_days(ts.time)

    fig.autofmt_xdate()

    def get_axis(ax, limit):
        """Set ``ax``'s x-range to ``limit`` and format its date ticks."""
        xmin, xmax = limit
        ax.set_xlim([xmin, xmax])

        ax.grid(True)

        # Minor ticks every 6 hours; the major locator is left as default.
        minorLocator = dates.HourLocator(interval=6)
        minorFormatter = dates.AutoDateFormatter(minorLocator)

        ax.xaxis.set_minor_locator(minorLocator)
        ax.xaxis.set_minor_formatter(minorFormatter)

        labels = ax.get_xminorticklabels()
        plt.setp(labels, rotation=30, fontsize='small')
        plt.setp(ax.get_xmajorticklabels(), rotation=30, fontsize='medium')

        xmin, xmax = ax.get_xlim()

        log.info(
            pformat({
                'xlim': [dates.num2date(xmin),
                         dates.num2date(xmax)],
                'xticks': dates.num2date(ax.get_xticks()),
            }))

    # zip stops at the shorter sequence: indexing grid[i] directly raised
    # IndexError as soon as there were more days than grid axes.
    for i, (ax, day) in enumerate(zip(grid, timestamps)):
        get_axis(ax, [day, day + delta])
        name = '%s-%d.png' % (day.isoformat(), i)
        # Prints the whole figure (all axes) under this day's file name.
        canvas.print_figure(name)

    return canvas
Beispiel #27
0
import matplotlib
# Use Computer Modern glyphs for mathtext, with a serif roman font.
matplotlib.rcParams['mathtext.fontset'] = 'cm'
matplotlib.rcParams['mathtext.rm'] = 'serif'

# n_snap = 0
# One figure per snapshot index 0..55.
# NOTE(review): fig_width, n_rows, n_cols, data_all, plt and ImageGrid come
# from outside this snippet - confirm they are in scope. In the visible code
# each `fig` is never saved or closed, so 56 figures stay open; confirm the
# (apparently truncated) remainder of the loop handles that.
for n_snap in range(0, 56):

    # Set up figure and image grid
    # n_rows x n_cols axes sharing one colour bar on the right.
    fig = plt.figure(figsize=(fig_width * n_cols, 10 * n_rows), )
    grid = ImageGrid(
        fig,
        111,  # as in plt.subplot(111)
        nrows_ncols=(n_rows, n_cols),
        axes_pad=0.2,
        share_all=True,
        cbar_location="right",
        cbar_mode="single",
        cbar_size="5%",
        cbar_pad=0.1,
    )

    colormap = 'turbo'
    alpha = 0.6

    # Plot window in data coordinates.
    x_min, x_max = -3, 5.5
    y_min, y_max = 1.8, 8.

    ax = grid[0]

    # presumably data_all is a per-snapshot sequence of dicts - TODO confirm
    dens_points = data_all[n_snap]['dens_points']
Beispiel #28
0
# One-hot encode the test labels into num_classes columns (num_classes is
# defined outside this snippet).
y_test = np_utils.to_categorical(y_test, num_classes)

# matplotlib inline
#
# Class index of every training sample; assumes y_train is one-hot encoded
# (presumably done earlier, outside this snippet - TODO confirm).
cat_y = np.argmax(y_train, 1)
f, ax = plt.subplots()
# Integer bins 0..10 with left alignment put one bar on each class id 0-9.
n, bins, patches = ax.hist(cat_y, bins=range(11), align='left', rwidth=0.5)
ax.set_xticks(bins[:-1])
ax.set_xlabel('Class Id')
ax.set_ylabel('# Samples')
ax.set_title('CIFAR-10 Class Distribution')

fig = plt.figure(figsize=(10., 10.))
grid = ImageGrid(
    fig,
    111,  # similar to subplot(111)
    nrows_ncols=(10, 10),  # creates 10x10 grid of axes
    axes_pad=(0.03, 0.1),  # pad between axes in inch
)

# Row i shows the first 10 training images of class i.
# NOTE(review): cat_im[j] raises IndexError if a class has fewer than 10
# samples.
for i in range(10):
    cls_id = i
    cat_im = X_train[cat_y == cls_id]

    for j in range(10):
        im = cat_im[j]
        im = im.squeeze()
        ax = grid[10 * i + j]
        ax.imshow(im, cmap='gray')
        ax.axis('off')
        ax.grid(False)
        # NOTE(review): looks like leftover debug output.
        print(j)
Beispiel #29
0
from orimat import qvec
from settings import e0, e1, e2, hist_bincenters, hist_binedges, a0_reciprocal

do_plotsurface = False  # set to True to plot surface orientation

runs = [4, 3, 8, 9, 7]
# Histogram bin centres/edges scaled by a0_reciprocal for the axes/extent.
qarray = hist_bincenters[0] * a0_reciprocal
xmin, xmax = (hist_binedges[1][0] * a0_reciprocal,
              hist_binedges[1][-1] * a0_reciprocal)
ymin, ymax = (hist_binedges[0][0] * a0_reciprocal,
              hist_binedges[0][-1] * a0_reciprocal)

## plot the 2d images
# One panel per run in a single row, sharing one colour bar on the right.
fig = plt.figure(figsize=(6, 8))
grid = ImageGrid(fig, 111, nrows_ncols=(1, 5), axes_pad=0.15,
                 share_all=True, cbar_location='right',
                 cbar_mode='single', cbar_size="15%", cbar_pad=0.15)
for i_run, run in enumerate(runs):
    ax = grid[i_run]
    filename = helper.getparam("filenames", run)
    sample = helper.getparam("samples", run)
    orimat = np.load("orimats/%s.npy" % (helper.getparam("orimats", run)))
    data_orig = np.load("data_hklmat/%s_interpolated_xy.npy" % filename)
    # log10 intensity, colour scale capped at 3.0.
    im = ax.imshow(np.log10(data_orig), extent=[xmin, xmax, ymin, ymax],
                   origin='lower', interpolation='nearest',
                   vmax=3.0)

    if do_plotsurface:
        surf_norm = orimat.dot(qvec(50., 25., 90., 0.))
        surf_norm = surf_norm / np.linalg.norm(surf_norm)
        # Python 3 print call; the Python 2 print statement was a SyntaxError.
        print("hkl of surface =", surf_norm)
Beispiel #30
0
fig = plt.figure(figsize=(6, 6))

# Prepare images
# Z is a 2-D sample array shipped with matplotlib's sample data.
Z = cbook.get_sample_data("axes_grid/bivariate_normal.npy", np_load=True)
extent = (-3, 4, -4, 3)
# Three images made of every third row of Z, starting at rows 0, 1 and 2.
ZS = [Z[i::3, :] for i in range(3)]
# Same extent with xmax divided by 3.
extent = extent[0], extent[1] / 3., extent[2], extent[3]

# *** Demo 1: colorbar at each axes ***
# 1x3 grid in the top half of the figure (211) with a colour bar above each
# panel.
grid = ImageGrid(
    fig,
    211,  # similar to subplot(211)
    nrows_ncols=(1, 3),
    axes_pad=0.05,
    label_mode="1",
    share_all=True,
    cbar_location="top",
    cbar_mode="each",
    cbar_size="7%",
    cbar_pad="1%",
)

for i, (ax, z) in enumerate(zip(grid, ZS)):
    im = ax.imshow(z, origin="lower", extent=extent)
    cb = ax.cax.colorbar(im)
    # Changing the colorbar ticks
    # Only the second and third colour bars get explicit ticks.
    if i in [1, 2]:
        cb.set_ticks([-1, 0, 1])

# add_inner_title is defined outside this snippet (presumably draws a label
# inside the axes at `loc` - TODO confirm).
for ax, im_title in zip(grid, ["Image 1", "Image 2", "Image 3"]):
    t = add_inner_title(ax, im_title, loc='lower left')
Beispiel #31
0
def create_channelmap(raster=None,
                      contour=None,
                      clevels=None,
                      zeropoint=0.,
                      channels=None,
                      ncols=4,
                      vrange=None,
                      vscale=None,
                      show=True,
                      pdfname=None,
                      cbarlabel=None,
                      cmap='RdYlBu_r',
                      beam=True,
                      **kwargs):
    """ Running this function will create a quick channel map of the Qube.
    One can either plot the contours or the raster image or both. This program
    can be used as a basis for a more detailed individual plot which can take
    better care of whitespace, etc. The following keywords are valid:

    raster:     qube object used to generate the raster image, leave blank or
                'None' if not desired.
    contour:    qube object used to generate the contours, leave blank or
                'None' if not desired.
    clevels:    nested lists of contour levels to draw, list should be the same
                length as the spectral dimension of the contour qube, if a
                single list is given assumes that the contours are the same
                for all channels. Required when contour is set.
    zeropoint:  Optional shift in velocities compared to the Restfrq keyword in
                the header of the Qube.
    channels:   indices of the channels to plot (default: all channels).
    ncols:      number of columns to plot
    vrange:     range of the z-axis (for consistency across panels) if None,
                will take minimum maximum of the Qube
    vscale:     tick spacing of the colorbar (if None, vrange / 5)
    show:       if True and pdfname is not set, call plt.show()
    pdfname:    if set, the figure is saved to this pdf file (and not shown)
    cbarlabel:  if set, label of the color bar
    cmap:       colormap to use for the plotting
    beam:       if True, draw the beam in the bottom-left panel only
    **kwargs:   passed on to standardfig for every panel

    Returns:
        fig, grid, channels

    Raises:
        ValueError if neither raster nor contour is given, or if contour is
        given without clevels.
    """

    # generate a temporary qube from the data
    if raster is not None:
        tqube = raster
    elif contour is not None:
        tqube = contour
    else:
        raise ValueError('Need to define either a contour or raster image.')

    # first genererate the velocity array
    VelArr = tqube._getvelocity_() - zeropoint

    # define the number of channels, rows (and columns = already defined)
    if channels is None:
        channels = np.arange(tqube.shape[0])
    nrows = np.ceil(len(channels) / ncols).astype(int)

    # define the range of the z-axis (if needed)
    if vrange is None:
        vrange = [np.nanmin(tqube.data), np.nanmax(tqube.data)]
    # now generate a grid with the specified number of columns
    fig = plt.figure(1, (8., 8. / ncols * nrows))
    grid = ImageGrid(fig,
                     111,
                     nrows_ncols=(nrows, ncols),
                     axes_pad=0.,
                     cbar_mode='single',
                     cbar_location='right',
                     share_all=True)

    # now loop over the channels
    for idx, chan in enumerate(channels):

        # get the string value of the velocity
        VelStr = str(int(round(VelArr[chan]))) + ' km s$^{-1}$'

        # draw the beam only in the bottom-left panel (first column, last row)
        beambool = bool(beam and idx % ncols == 0 and
                        idx // ncols == int(nrows) - 1)

        # plot the individual channels
        if raster is not None:
            rasterimage = raster.get_slice(zindex=(chan, chan + 1))
        else:
            rasterimage = None
        if contour is not None:
            contourimage = contour.get_slice(zindex=(chan, chan + 1))
            # also get the contour levels
            if clevels is None:
                raise ValueError('Set the contour levels using the clevels ' +
                                 'keyword')
            elif isinstance(clevels[0], (list, np.ndarray)):
                # one list of levels per channel
                clevel = clevels[chan]
            else:
                clevel = clevels
        else:
            contourimage = None
            clevel = clevels

        standardfig(raster=rasterimage,
                    contour=contourimage,
                    clevels=clevel,
                    newplot=False,
                    fig=fig,
                    ax=grid[idx],
                    cbar=False,
                    beam=beambool,
                    vrange=vrange,
                    text=[VelStr],
                    cmap=cmap,
                    **kwargs)

    # now do the color bar
    norm = mpl.colors.Normalize(vmin=vrange[0], vmax=vrange[1])
    cmapo = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    cmapo.set_array([])
    if vscale is None:
        vscale = (vrange[1] - vrange[0]) / 5.
    # Tick every multiple of vscale inside vrange. The fixed
    # np.arange(-10, 10) * vscale only spans [-10, 9] * vscale, so a vrange
    # far from zero (e.g. [1000, 1010]) got no visible ticks at all.
    with np.errstate(divide='ignore', invalid='ignore'):
        first = np.floor(np.divide(vrange[0], vscale))
        last = np.ceil(np.divide(vrange[1], vscale))
    if np.isfinite(first) and np.isfinite(last) and 0 <= last - first <= 1000:
        ticks = np.arange(first, last + 1) * vscale
    else:
        # degenerate (zero, negative or non-finite) vscale: legacy ticks
        ticks = np.arange(-10, 10) * vscale
    cbr = plt.colorbar(cmapo,
                       cax=grid.cbar_axes[0],
                       ticks=ticks)

    # label of color bar
    if cbarlabel is not None:
        cbr.ax.set_ylabel(cbarlabel, labelpad=-1)

    if pdfname is not None:
        plt.savefig(pdfname, format='pdf', dpi=300)
    elif show:
        plt.show()

    return fig, grid, channels
Beispiel #32
0
def ransac_im_fit(im,
                  mode=1,
                  residual_threshold=0.1,
                  min_samples=10,
                  max_trials=1000,
                  model_f=None,
                  p0=None,
                  mask=None,
                  scale=False,
                  fract=1,
                  param_dict=None,
                  plot=False,
                  axes=(-2, -1),
                  widget=None):
    '''
    Fits a plane, polynomial, convex paraboloid, arbitrary function, or
    smoothing spline to an image using the RANSAC algorithm.

    Parameters
    ----------
    im : ndarray
        ndarray with images to fit to.
    mode : integer [0-4]
        Specifies model used for fit.
        0 is function defined by `model_f`.
        1 is plane.
        2 is quadratic.
        3 is concave paraboloid with offset.
        4 is smoothing spline.
        Any other value raises ValueError.
    model_f : callable or None
        Function to be fitted.
        Definition is model_f(p, *args), where p is 1-D iterable of params
        and args is iterable of (x, y, z) arrays of point cloud coordinates.
        See examples.
    p0 : tuple
        Initial guess of fit params for `model_f`.
    mask : 2-D boolean array
        Array with which to mask data. True values are ignored.
    scale : bool
        If True, `residual_threshold` is scaled by the stdev of the sampled
        data of each image (independently for every image).
    fract : scalar (0, 1]
        Fraction of data used for fitting, chosen randomly. Non-used data
        locations are set as nans in `inliers`.
    residual_threshold : float
        Maximum distance for a data point to be classified as an inlier.
    min_samples : int or float
        The minimum number of data points to fit a model to.
        If an int (Python or NumPy integer), the value is the number of pixels.
        If a float, the value is a fraction (0.0, 1.0] of the total number of pixels.
    max_trials : int, optional
        Maximum number of iterations for random sample selection.
    param_dict : None or dictionary.
        If not None, the dictionary is passed to the model estimator.
        For arbitrary functions, this is passed to scipy.optimize.leastsq.
        For spline fitting, this is passed to scipy.interpolate.SmoothBivariateSpline.
        All other models take no parameters.
    plot : bool
        If True, the data, including inliers, model, etc are plotted.
    axes : length 2 iterable
        Indices of the input array with images.
    widget : QtWidget
        A qt widget that must contain as many widgets or dock widget as there is popup window

    Returns
    -------
    Tuple of fit, inliers, n, where:
    fit : 2-D array
        Image of fitted model.
    inliers : 2-D array
        Boolean array describing inliers. When `mask` is given or `fract`
        is below 1, a float array with nan at pixels not used in the fit.
    n : array or None
        Normal of plane fit. `None` for other models.

    Raises
    ------
    NotImplementedError
        If `mode` is 0 and `p0` is None.
    ValueError
        If `mode` is not in [0, 4].

    Notes
    -----
    See skimage.measure.ransac for details of RANSAC algorithm.

    `min_samples` should be chosen appropriate to the size of the image
    and to the variation in the image.

    Increasing `residual_threshold` increases the fraction of the image
    fitted to.

    The entire image can be fitted to without RANSAC by setting:
    max_trials=1, min_samples=1.0, residual_threshold=`x`, where `x` is a
    suitably large value.

    Examples
    --------
    `model_f` for paraboloid with offset:

    >>> def model_f(p, *args):
    ...     (x, y, z) = args
    ...     m = np.abs(p[0])*((x-p[1])**2 + (y-p[2])**2) + p[3]
    ...     return m
    >>> p0 = (0.1, 10, 20, 0)


    To plot fit, inliers etc:

    >>> from fpd.ransac_tools import ransac_im_fit
    >>> import matplotlib as mpl
    >>> from numpy.ma import masked_where
    >>> import numpy as np
    >>> import matplotlib.pylab as plt
    >>> plt.ion()


    >>> cmap = mpl.cm.gray
    >>> cmap.set_bad('r')

    >>> image = np.random.rand(*(64,)*2)
    >>> fit, inliers, n = ransac_im_fit(image, mode=1)
    >>> cor_im = image-fit

    >>> pct = 0.5
    >>> vmin, vmax = np.percentile(cor_im, [pct, 100-pct])
    >>>
    >>> f, axs = plt.subplots(1, 4, sharex=True, sharey=True)
    >>> _ = axs[0].matshow(image, cmap=cmap)
    >>> _ = axs[1].matshow(masked_where(inliers==False, image), cmap=cmap)
    >>> _ = axs[2].matshow(fit, cmap=cmap)
    >>> _ = axs[3].matshow(cor_im, vmin=vmin, vmax=vmax)



    To plot plane normal vs threshold:

    >>> from fpd.ransac_tools import ransac_im_fit
    >>> from numpy.ma import masked_where
    >>> import numpy as np
    >>> from tqdm import tqdm
    >>> import matplotlib.pylab as plt
    >>> plt.ion()

    >>> image = np.random.rand(*(64,)*2)
    >>> ns = []
    >>> rts = np.logspace(0, 1.5, 5)
    >>> for rt in tqdm(rts):
    ...     nis = []
    ...     for i in range(64):
    ...         fit, inliers, n = ransac_im_fit(image, residual_threshold=rt, max_trials=10)
    ...         nis.append(n)
    ...     ns.append(np.array(nis).mean(0))
    >>> ns = np.array(ns)

    >>> thx = np.arctan2(ns[:,1], ns[:,2])
    >>> thy = np.arctan2(ns[:,0], ns[:,2])
    >>> thx = np.rad2deg(thx)
    >>> thy = np.rad2deg(thy)

    >>> _ = plt.figure()
    >>> _ = plt.semilogx(rts, thx)
    >>> _ = plt.semilogx(rts, thy)

    '''

    # Set model
    # Functions defining classes are needed to pass parameters since class must
    # not be instantiated or are monkey patched (only in spline implementation)
    if mode == 0:
        # generate model_class with passed function
        if p0 is None:
            raise NotImplementedError('p0 must be specified.')
        model_class = rt._model_class_gen(model_f, p0, param_dict)
    elif mode == 1:
        # linear
        model_class = rt._Plane3dModel
    elif mode == 2:
        # quadratic
        model_class = rt._Poly3dModel
    elif mode == 3:
        # concave paraboloid
        model_class = rt._Poly3dParaboloidModel
    elif mode == 4:
        # spline: subclass so that setting param_dict leaves the shared
        # class untouched (looked up via rt like the other models)
        class _Spline3dModel_monkeypatched(rt._Spline3dModel):
            pass

        model_class = _Spline3dModel_monkeypatched
        model_class.param_dict = param_dict
    else:
        # previously model_class was left unbound and failed later
        raise ValueError('mode must be an integer in [0, 4], got %r.' % (mode,))

    multiim = im.ndim > 2
    if multiim:
        axes = [int(el) for el in axes]
        ims, unflat_shape = seq_image_array(im, axes)
        pbar = tqdm(total=ims.shape[0])
    else:
        ims = im[None]

    # Resolve min_samples once: an integer (Python or NumPy) is a pixel
    # count, anything else a fraction of the pixels in one image.
    if isinstance(min_samples, (int, np.integer)):
        n_min_samples = int(min_samples)
    else:
        n_min_samples = int(np.prod(ims.shape[1:]) * min_samples)
        print("min_samples is set to: %d" % (n_min_samples))

    # Pixels eligible for fitting (mask is shared by all images).
    if mask is None:
        keep = np.ones(ims.shape[1:], dtype=bool).ravel()
    else:
        keep = np.logical_not(mask).ravel()
    keep_idx = np.flatnonzero(keep)

    fits = []
    inlierss = []
    ns = []

    for imi in ims:
        # set data: point cloud of (x, y, z) for the kept pixels
        yy, xx = np.indices(imi.shape)
        zz = imi
        data = np.column_stack([xx.flat[keep], yy.flat[keep], zz.flat[keep]])

        # randomly select data
        sel = np.random.rand(data.shape[0]) <= fract
        data = data[sel]

        # scale a per-image threshold so it does not compound across images
        if scale:
            threshold = residual_threshold * np.std(data[:, 2])
        else:
            threshold = residual_threshold

        # determine if fitting to all
        full_fit = n_min_samples == data.shape[0]

        if not full_fit:
            # do ransac fit
            model, inliers = ransac(data=data,
                                    model_class=model_class,
                                    min_samples=n_min_samples,
                                    residual_threshold=threshold,
                                    max_trials=max_trials)
        else:
            model = model_class()
            inliers = np.ones(data.shape[0], dtype=bool)

        # get params from fit with all inliers
        model.estimate(data[inliers])
        # get model over all x, y
        args = (xx.flatten(), yy.flatten(), zz.flatten())
        fit = model.my_model(model.params, *args).reshape(imi.shape)

        if mask is None and fract == 1:
            inliers = inliers.reshape(imi.shape)
        else:
            # Inexact dtype so unused pixels can hold nan even for integer
            # images (np.empty_like of an int image cannot store nan).
            if np.issubdtype(imi.dtype, np.inexact):
                nan_dtype = imi.dtype
            else:
                nan_dtype = float
            inliers_nans = np.full(imi.size, np.nan, dtype=nan_dtype)
            inliers_nans[keep_idx[sel]] = inliers
            inliers = inliers_nans.reshape(imi.shape)

        # calculate normal for plane
        if mode == 1:
            # linear
            C = model.params
            n = np.array([-C[0], -C[1], 1])
            n = n / np.linalg.norm(n)
        else:
            # non-linear
            n = None

        if plot:
            import copy
            import matplotlib.pylab as plt
            import matplotlib as mpl
            from numpy.ma import masked_where
            from mpl_toolkits.axes_grid1 import ImageGrid

            plt.ion()
            # copy so set_bad does not alter matplotlib's global gray map
            cmap = copy.deepcopy(mpl.cm.gray)
            cmap.set_bad('r')

            cor_im = imi - fit
            pct = 0.1
            vmin, vmax = np.percentile(cor_im, [pct, 100 - pct])
            if widget is None:
                fig = plt.figure()
            else:
                docked = widget.setup_docking("Circular Center", "Bottom")
                fig = docked.get_fig()
                fig.clf()

            grid = ImageGrid(fig,
                             111,
                             nrows_ncols=(1, 4),
                             axes_pad=0.1,
                             share_all=True,
                             label_mode="L",
                             cbar_location="right",
                             cbar_mode="single")

            images = [imi, masked_where(inliers == False, imi), fit, cor_im]
            titles = ['Image', 'Inliers', 'Fit', 'Corrected']
            for i, image in enumerate(images):
                img = grid[i].imshow(image, cmap=cmap, interpolation='nearest')
                grid[i].set_title(titles[i])
            # colour limits and colour bar follow the last (corrected) panel
            img.set_clim(vmin, vmax)
            grid.cbar_axes[0].colorbar(img)

        fits.append(fit)
        inlierss.append(inliers)
        ns.append(n)
        if multiim:
            pbar.update(1)
    fit = np.array(fits)
    inliers = np.array(inlierss)
    n = np.array(ns)

    if multiim:
        pbar.close()

        # reshape
        fit = unseq_image_array(fit, axes, unflat_shape)
        inliers = unseq_image_array(inliers, axes, unflat_shape)
        n = unseq_image_array(n, axes, unflat_shape)
    else:
        fit = fit[0]
        inliers = inliers[0]
        n = n[0]

    return (fit, inliers, n)
Beispiel #33
0
def modeldata_comparison(cube, pdffile=None, **kwargs):
    """Make a six-panel plot giving an easy overview of the model and the data.

    Top row: the zeroth moment of the data, the model and the residual
    (data - model), all drawn with the same contour levels and color range.
    Bottom row: the first moment of the data and the model, and the residual
    velocity field (data minus model first moment). The first moments are
    computed after masking the cubes with
    ``dMom0.mask_region(value=3 * Mom0sig)`` (presumably the region above
    3 sigma in the data moment-zero map).

    inputs:
    -------
    cube: Qubefit object that holds the data and the model

    keywords:
    ---------
    pdffile (string|default: None): If set, the figure is saved to this file
        in pdf format.

    showmask (Bool|default: False): If set, it will show a contour of the mask
        used in creating the fit (stored in cube.maskarray).
        NOTE(review): not handled in this function; like every other extra
        keyword it is only forwarded to standardfig -- confirm it is
        supported there.

    returns:
    --------
    fig: the matplotlib figure holding the six panels.
    """

    # create the figure plots
    fig = plt.figure(1, (8., 8.))
    grid = ImageGrid(fig, 111, nrows_ncols=(2, 3), axes_pad=0)

    # create the data, model and residual cubes
    dqube = dc(cube)
    mqube = dc(cube)
    mqube.data = cube.model
    rqube = dc(cube)
    rqube.data = cube.data - cube.model

    # moment-zero data
    dMom0 = dqube.calculate_moment(moment=0)
    Mom0sig = (np.sqrt(np.nansum(cube.variance[:, 0, 0])) *
               cube.__get_velocitywidth__())
    # contours at +-3, 6, ..., 27 sigma (and -30 sigma)
    clevels = np.insert(np.arange(3, 30, 3), 0, np.arange(-30, 0, 3)) * Mom0sig
    vrangemom = [-3 * Mom0sig, 11 * Mom0sig]
    mask = dMom0.mask_region(value=Mom0sig * 3, applymask=False)
    standardfig(raster=dMom0,
                contour=dMom0,
                clevels=clevels,
                ax=grid[0],
                fig=fig,
                vrange=vrangemom,
                cmap='RdYlBu_r',
                text='Data',
                textprop=[dict(size=12)],
                **kwargs)

    # moment-zero model
    mMom0 = mqube.calculate_moment(moment=0)
    standardfig(raster=mMom0,
                contour=mMom0,
                clevels=clevels,
                ax=grid[1],
                fig=fig,
                vrange=vrangemom,
                cmap='RdYlBu_r',
                text='Model',
                textprop=[dict(size=12)],
                **kwargs)

    # moment-zero residual
    rMom0 = rqube.calculate_moment(moment=0)
    standardfig(raster=rMom0,
                contour=rMom0,
                clevels=clevels,
                ax=grid[2],
                fig=fig,
                vrange=vrangemom,
                cmap='RdYlBu_r',
                text='Residual',
                textprop=[dict(size=12)],
                **kwargs)

    # moment-one data
    dsqube = dqube.mask_region(mask=mask)
    dMom1 = dsqube.calculate_moment(moment=1)
    vrangemom1 = [0.95 * np.nanmin(dMom1.data), 0.95 * np.nanmax(dMom1.data)]
    standardfig(raster=dMom1,
                ax=grid[3],
                fig=fig,
                cmap='Spectral_r',
                vrange=vrangemom1,
                text='Data',
                textprop=[dict(size=12)],
                **kwargs)

    # moment-one model (same mask and color range as the data)
    msqube = mqube.mask_region(mask=mask)
    mMom1 = msqube.calculate_moment(moment=1)
    standardfig(raster=mMom1,
                ax=grid[4],
                fig=fig,
                cmap='Spectral_r',
                vrange=vrangemom1,
                beam=False,
                text='Model',
                textprop=[dict(size=12)],
                **kwargs)

    # moment-one residual
    rMom1 = dc(mMom1)
    rMom1.data = dMom1.data - mMom1.data
    standardfig(raster=rMom1,
                ax=grid[5],
                fig=fig,
                cmap='Spectral_r',
                vrange=vrangemom1,
                beam=False,
                text='Residual',
                textprop=[dict(size=12)],
                **kwargs)

    # honor the documented pdffile keyword (previously ignored)
    if pdffile is not None:
        fig.savefig(pdffile, format='pdf')

    return fig
Beispiel #34
0
class Plot(object):

    def __init__(self, kind='', figsize=None, nrows=1, ncols=1, rect=111,
                 cbar_mode='single', squeeze=False, **kwargs):
        self._create_subplots(kind=kind, figsize=figsize, nrows=nrows,
            ncols=ncols, **kwargs)

    def _create_subplots(self, kind='', figsize=None, nrows=1, ncols=1, rect=111,
        cbar_mode='single', squeeze=False, **kwargs):
        """
        Create (or re-create) the figure and its grid of axes.

        If a figure was created earlier by this object, it is cleared and its
        figure number is reused.

        :Kwargs:
            - kind (str, default: '')
                The kind of plot. For plotting matrices or images
                (`matplotlib.pyplot.imshow`), choose `matrix`, otherwise leave
                blank.
            - figsize (tuple, default: None)
                Size of the figure.
            - nrows, ncols (int, default: 1)
                Shape of subplot arrangement; ignored if `nrows_ncols` is
                given.
            - nrows_ncols (tuple, default: (nrows, ncols))
                Shape of subplot arrangement.
            - rect (default: 111)
                Position of the `ImageGrid` in the figure (`matrix` only).
            - cbar_mode (str, default: 'single')
                `ImageGrid` colorbar mode (`matrix` only).
            - squeeze (bool, default: False)
                Passed to `matplotlib.pyplot.subplots` (ignored for `matrix`).
            - **kwargs
                A dictionary of keyword arguments that `matplotlib.ImageGrid`
                or `matplotlib.pyplot.subplots` accept.
        :Returns:
            `matplotlib.pyplot.figure` and a grid of axes (an `ImageGrid` for
            `matrix`, otherwise a flat array of axes).
        """
        nrows_ncols = kwargs.pop('nrows_ncols', (nrows, ncols))
        try:
            # reuse the number of the figure this object created before
            num = self.fig.number
            self.fig.clf()
        except AttributeError:  # no figure created yet
            num = None
        if kind == 'matrix':
            self.fig = plt.figure(figsize=figsize, num=num)
            self.axes = ImageGrid(self.fig, rect,
                                  nrows_ncols=nrows_ncols,
                                  cbar_mode=cbar_mode,
                                  **kwargs
                                  )
        else:
            self.fig, axes = plt.subplots(
                nrows=nrows_ncols[0],
                ncols=nrows_ncols[1],
                figsize=figsize,
                squeeze=squeeze,
                num=num,
                **kwargs
                )
            # turn axes into a flat list; asarray also handles the single
            # Axes object that plt.subplots returns when squeeze=True
            self.axes = np.asarray(axes, dtype=object).ravel()
        self.kind = kind
        self.subplotno = -1  # will get +1 after the plot command
        self.nrows_ncols = nrows_ncols
        return (self.fig, self.axes)

    def __getattr__(self, name):
        """Pass on a `matplotlib` function that we haven't modified
        """
        def method(*args, **kwargs):
            return getattr(plt, name)(*args, **kwargs)

        try:
            return method  # is it a function?
        except TypeError:  # so maybe it's just a self variable
            return getattr(self, name)

    def __getitem__(self, key):
        """Allow to get axes as Plot()[key]
        """
        if key > len(self.axes):
            raise IndexError
        if key < 0:
            key += len(self.axes)
        return self.axes[key]

    def get_ax(self, subplotno=None):
        """
        Returns the current or the requested axis from the current figure.

        :note: The :class:`Plot()` is indexable so you should access axes as
        `Plot()[key]` unless you want to pass a list like (row, col).

        :Kwargs:
            subplotno (int, default: None)
                Give subplot number explicitly if you want to get not the
                current axis

        :Returns:
            ax
        """
        if subplotno is None:
            no = self.subplotno
        else:
            no = subplotno

        if isinstance(no, int):
            ax = self.axes[no]
        else:
            if no[0] < 0: no += len(self.axes._nrows)
            if no[1] < 0: no += len(self.axes._ncols)

            if isinstance(self.axes, ImageGrid):  # axes are a list
                if self.axes._direction == 'row':
                    no = self.axes._ncols * no[0] + no[1]
                else:
                    no = self.axes._nrows * no[0] + no[1]
            else:  # axes are a grid
                no = self.axes._ncols * no[0] + no[1]
            ax = self.axes[no]

        return ax

    def next(self):
        """
        Returns the next axis.

        This is useful when a plotting function is not implemented by
        :mod:`plot` and you have to instead rely on matplotlib's plotting
        which does not advance axes automatically.
        """
        self.subplotno += 1
        return self.get_ax()

    def sample_paired(self, ncolors=2):
        """
        Returns `ncolors` colors from `matplotlib.cm.Paired`.

        For up to 12 colors, the entries sampled at odd positions (1, 3, ...,
        11 over 11) come first, followed by those at even positions in reverse
        order (10, 8, ..., 0 over 11). For more colors, the colormap is sampled
        evenly over its full [0, 1] range.
        """
        if ncolors <= 12:
            colors_full = [mpl.cm.Paired(i * 1. / 11) for i in range(1, 12, 2)]
            colors_pale = [mpl.cm.Paired(i * 1. / 11) for i in range(10, -1, -2)]
            colors = colors_full + colors_pale
            return colors[:ncolors]
        # Colormaps are sampled with floats in [0, 1]; the previous
        # np.linspace(0, ncolors) gave 50 values mostly above 1 (clipped).
        return [mpl.cm.Paired(c) for c in np.linspace(0, 1, ncolors)]

    def get_colors(self, ncolors=2, cmap='Paired'):
        """
        Get a list of nice colors for plots.

        Uses the colors of the current matplotlib color cycle when it has at
        least `ncolors` of them, :meth:`sample_paired` for up to 12 colors,
        and evenly spaced samples of `cmap` (excluding its two ends) otherwise.

        FIX: This function is happy to ignore the ugly settings you may have in
        your matplotlibrc settings.
        TODO: merge with mpltools.color

        :Kwargs:
            ncolors (int, default: 2)
                Number of colors required. Typically it should be the number of
                entries in the legend.
            cmap (str or matplotlib.cm, default: 'Paired')
                A colormap to sample from when ncolors > 12

        :Returns:
            a list of colors (an array of RGBA rows when sampled from `cmap`)
        """
        rc = plt.rcParams
        if 'axes.prop_cycle' in rc:
            # current matplotlib: colors live in the property cycler
            colorc = [props['color'] for props in rc['axes.prop_cycle']
                      if 'color' in props]
        else:
            # old matplotlib without prop_cycle
            colorc = list(rc['axes.color_cycle'])
        if ncolors <= len(colorc):
            colors = colorc[:ncolors]
        elif ncolors <= 12:
            colors = self.sample_paired(ncolors=ncolors)
        else:
            # plt.get_cmap exists in every matplotlib version, unlike
            # mpl.cm.get_cmap; it also accepts a Colormap instance
            thisCmap = plt.get_cmap(cmap)
            norm = mpl.colors.Normalize(0, 1)
            z = np.linspace(0, 1, ncolors + 2)[1:-1]  # skip both extremes
            colors = thisCmap(norm(z))
        return colors

    def pivot_plot(self,df,rows=None,cols=None,values=None,yerr=None,
                   **kwargs):
        agg = self.aggregate(df, rows=rows, cols=cols,
                                 values=values, yerr=yerr)
        if yerr is None:
            no_yerr = True
        else:
            no_yerr = False
        return self._plot(agg, no_yerr=no_yerr,**kwargs)


    def _plot(self, agg, ax=None,
                   title='', kind='bar', xtickson=True, ytickson=True,
                   no_yerr=False, numb=False, autoscale=True, **kwargs):
        """DEPRECATED plotting function"""
        print "plot._plot() has been DEPRECATED; please don't use it anymore"
        self.plot(agg, ax=ax,
                   title=title, kind=kind, xtickson=xtickson, ytickson=ytickson,
                   no_yerr=no_yerr, numb=numb, autoscale=autoscale, **kwargs)

    def plot(self, agg, subplots=None, **kwargs):
        """
        The main plotting function.

        :Args:
            agg (`pandas.DataFrame` or similar)
                A structured input, preferably a `pandas.DataFrame`, but in
                principle accepts anything that can be converted into it.

        :Kwargs:
            - subplots (None, True, or False; default=None)
                Whether you want to split data into subplots or not. If True,
                the top level is treated as a subplot. If None, detects
                automatically based on `agg.columns.names` -- the first entry
                to start with `subplots.` will be used. This is the default
                output from `stats.aggregate` and is recommended.
            - **kwargs
                Keyword arguments for plotting

        :Returns:
            A list of axes of all plots.
        """
        agg = pandas.DataFrame(agg)
        axes = []
        try:
            s_idx = [s for s,n in enumerate(agg.columns.names) if n.startswith('subplots.')]
        except:
            s_idx = None
        if s_idx is not None:  # subplots implicit in agg
            if len(s_idx) != 0:
                sbp = agg.columns.levels[s_idx[0]]
            else:
                sbp = None
        elif subplots:  # get subplots from the top level column
            sbp = agg.columns.levels[0]
        else:
            sbp = None

        if sbp is None:
            axes = [self._plot_ax(agg, **kwargs)]
        else:
            # if haven't made any plots yet...
            if self.subplotno == -1:
                num_subplots = len(sbp)
                # ...can still adjust the number of subplots
                if num_subplots > len(self.axes):
                    self._create_subplots(ncols=num_subplots)

            for no, subname in enumerate(sbp):
                # all plots are the same, onle legend will suffice
                if subplots is None or subplots:
                    if no == 0:
                        legend = True
                    else:
                        legend = False
                else:  # plots vary; each should get a legend
                    legend = True
                ax = self._plot_ax(agg[subname], title=subname, legend=legend,
                                **kwargs)
                if 'title' in kwargs:
                    ax.set_title(kwargs['title'])
                else:
                    ax.set_title(subname)
                axes.append(ax)
        return axes

    def _plot_ax(self, agg, ax=None,
                   title='', kind='bar', legend=True,
                   xtickson=True, ytickson=True,
                   no_yerr=False, numb=False, autoscale=True, order=None,
                   **kwargs):
        """
        Draw `agg` on a single axis.

        :Args:
            agg (`pandas.DataFrame` or `pandas.Series`)
                A DataFrame is reduced to column means and SEMs with
                :meth:`errorbars`; a Series is plotted as is with zero error
                bars. Index levels named `cols.*` become legend entries.

        :Kwargs:
            - ax (default: None)
                Axis to draw on; if None, advances to the next subplot.
            - title (str, default: '')
            - kind ({'bar', 'line', 'bean'}, default: 'bar')
            - legend (bool or None, default: True)
                Passed to :meth:`_draw_legend` as `visible`.
            - xtickson, ytickson (bool, default: True)
                If False, tick labels on that axis are blanked.
            - no_yerr (bool)
                Accepted for compatibility; not used here.
            - numb (bool, default: False)
                If True, label the axis with its subplot number.
            - autoscale (bool, default: True)
                Compute y limits from the means and error bars (skipped when
                `ylim` is in kwargs; always off for 'bean').
            - order
                Passed to `beanplot`.
            - **kwargs
                `xticklabels`, `ylim`, `xlabel`, `ylabel` and `legend_visible`
                are read here; all kwargs go to `beanplot` and `_draw_legend`.

        :Returns:
            The axis drawn on.
        """
        if ax is None:
            self.subplotno += 1
            ax = self.get_ax()
        if isinstance(agg, pandas.DataFrame):
            mean, p_yerr = self.errorbars(agg)
        else:
            mean = agg
            # zero error bars sharing mean's index, so they stay aligned with
            # mean through the reshaping below
            p_yerr = pandas.Series(np.zeros(len(agg)), index=agg.index)

        if mean.index.nlevels == 1:  # oops, nothing to unstack
            mean = pandas.DataFrame(mean).T
            p_yerr = pandas.DataFrame(p_yerr).T
        else:
            # make columns which will turn into legend entries; mean's index
            # levels are agg's column levels (or agg's own index for a Series)
            for name in list(mean.index.names):
                if hasattr(name, 'startswith') and name.startswith('cols.'):
                    mean = mean.unstack(level=name)
                    p_yerr = p_yerr.unstack(level=name)

        if isinstance(agg, pandas.Series) and kind=='bean':
            kind = 'bar'
            print('WARNING: Beanplot not available for a single measurement')

        if kind == 'bar':
            self.barplot(mean, yerr=p_yerr, ax=ax)
        elif kind == 'line':
            self.lineplot(mean, yerr=p_yerr, ax=ax)
        elif kind == 'bean':
            autoscale = False  # FIX: autoscaling is incorrect on beanplots
            ax = self.beanplot(agg, ax=ax, order=order, **kwargs)
        else:
            raise Exception('%s plot not recognized. Choose from '
                            '{bar, line, bean}.' %kind)

        # TODO: xticklabel rotation business is too messy
        if 'xticklabels' in kwargs:
            ax.set_xticklabels(kwargs['xticklabels'], rotation=0)
        if not xtickson:
            ax.set_xticklabels(['']*len(ax.get_xticklabels()))

        # rotate every tick label vertically if any of them is long
        labels = ax.get_xticklabels()
        max_len = max([len(label.get_text()) for label in labels] + [0])
        for label in labels:
            if max_len > 20:
                label.set_rotation(90)
            else:
                label.set_rotation(0)

        if not ytickson:
            ax.set_yticklabels(['']*len(ax.get_yticklabels()))
        ax.set_xlabel('')

        # set y-axis limits
        if 'ylim' in kwargs:
            ax.set_ylim(kwargs['ylim'])
        elif autoscale:
            mean_array = np.asarray(mean)
            r = np.max(mean_array) - np.min(mean_array)
            # missing (NaN) error bars count as a third of the data range
            ebars = np.where(np.isnan(p_yerr), r/3., p_yerr)
            if kind == 'bar':
                ymin = np.min(mean_array - ebars)
                if ymin > 0:
                    ymin = 0  # bars start at zero
                else:
                    ymin = np.min(mean_array - 3*ebars)
            else:
                ymin = np.min(mean_array - 3*ebars)
            if kind == 'bar':
                ymax = np.max(mean_array + ebars)
                if ymax < 0:
                    ymax = 0  # bars start at zero
                else:
                    ymax = np.max(mean_array + 3*ebars)
            else:
                ymax = np.max(mean_array + 3*ebars)
            ax.set_ylim([ymin, ymax])

        # set x and y labels
        if 'xlabel' in kwargs:
            ax.set_xlabel(kwargs['xlabel'])
        else:
            ax.set_xlabel(self._get_title(mean, 'rows'))
        if 'ylabel' in kwargs:
            ax.set_ylabel(kwargs['ylabel'])
        else:
            ax.set_ylabel(self._get_title(mean, 'cols'))

        # set x tick labels
        # FIX: data.index returns float even if it is int because dtype=object
        ax.set_xticklabels(mean.index.tolist())

        ax.set_title(title)
        self._draw_legend(ax, visible=legend, data=mean, **kwargs)
        if numb == True:
            self.add_inner_title(ax, title='%s' % self.subplotno, loc=2)

        return ax

    def _get_title(self, data, pref):
        if pref == 'cols':
            dnames = data.columns.names
        else:
            dnames = data.index.names
        title = [n.split('.',1)[1] for n in dnames if n.startswith(pref+'.')]

        title = ', '.join(title)
        return title

    def _draw_legend(self, ax, visible=True, data=None, **kwargs):
        """
        Style the legend of `ax` (creating it if needed) and set its visibility.

        :Kwargs:
            - visible (bool or None, default: True)
                Legend visibility. If None, the legend is hidden when it has
                a single entry and shown otherwise.
            - data (`pandas.DataFrame`, default: None)
                If given, the legend title is built from its `cols.*` column
                level names.
            - legend_visible (bool, in **kwargs)
                If present, overrides `visible`.
        """
        l = ax.get_legend()  # get an existing legend
        if l is None:  # create a new legend
            l = ax.legend()
        if l is None:  # old matplotlib returns None when nothing is labeled
            return
        l.legendPatch.set_alpha(0.5)
        if data is not None:
            l.set_title(self._get_title(data, 'cols'))

        if 'legend_visible' in kwargs:
            l.set_visible(kwargs['legend_visible'])
        elif visible is not None:
            l.set_visible(visible)
        else:  # decide automatically
            # showing a single legend entry is useless
            l.set_visible(len(l.texts) != 1)

    def hide_plots(self, nums):
        """
        Hides an axis.

        :Args:
            nums (int, tuple or list of ints)
                Which axes to hide.
        """
        if isinstance(nums, int) or isinstance(nums, tuple):
            nums = [nums]
        for num in nums:
            ax = self.get_ax(num)
            ax.axis('off')

    def barplot(self, data, yerr=None, ax=None):
        """
        Plots a grouped bar plot.

        :Args:
            data (`pandas.DataFrame` or any other array accepted by it)
                A data frame where rows go to the x-axis and columns go to the
                legend.

        :Kwargs:
            - yerr (`pandas.DataFrame`-like, default: None)
                Error bar sizes, indexable by the column labels of `data`.
                If None, no error bars are drawn.
            - ax (default: None)
                Axis to draw on; if None, advances to the next subplot.

        :Returns:
            The axis drawn on.
        """
        data = pandas.DataFrame(data)
        if yerr is None:
            # all-NaN error bars with data's labels, so yerr[label] works below
            # (previously yerr became the scalar np.nan and indexing it failed)
            yerr = pandas.DataFrame(np.nan, index=data.index,
                                    columns=data.columns)
        if ax is None:
            self.subplotno += 1
            ax = self.get_ax()

        colors = self.get_colors(len(data.columns))

        n = len(data.columns)
        idx = np.arange(len(data))
        width = .75 / n
        rects = []
        for i, label in enumerate(data.columns):
            column = data.iloc[:, i]
            # align='edge' puts each bar's left edge at x (the matplotlib < 2
            # default) which the tick positions below are computed for
            rect = ax.bar(idx+i*width+width/2, column, width, label=str(label),
                yerr=yerr[label].tolist(), color=colors[i], ecolor='black',
                align='edge')
            # TODO: yerr indexing might need fixing
            rects.append(rect)
        ax.set_xticks(idx + width*n/2 + width/2)  # center of each bar group
        ax.legend(rects, data.columns.tolist())

        return ax

    def lineplot(self, data, yerr=None, ax=None):
        """
        Plots a line plot with black error bars.

        :Args:
            data (`pandas.DataFrame` or any other array accepted by it)
                A data frame where rows go to the x-axis and columns go to the
                legend.

        :Kwargs:
            - yerr (`pandas.DataFrame`-like, default: None)
                Error bar sizes, indexable by the column labels of `data`.
                If None, no error bars are drawn.
            - ax (default: None)
                Axis to draw on; if None, advances to the next subplot.

        :Returns:
            The axis drawn on.
        """
        data = pandas.DataFrame(data)
        if yerr is None:
            # all-NaN error bars with data's labels, so yerr[label] works below
            yerr = pandas.DataFrame(np.nan, index=data.index,
                                    columns=data.columns)
        if ax is None:
            self.subplotno += 1
            ax = self.get_ax()

        x = range(len(data))
        for i, label in enumerate(data.columns):
            column = data.iloc[:, i]
            ax.plot(x, column, label=str(label))
            # fmt='none' draws only the error bars, not a second line
            ax.errorbar(x, column, yerr=yerr[label].tolist(), fmt='none',
                ecolor='black')
        return ax

    def scatter(self, x, y, ax=None, labels=None, title='', **kwargs):
        """
        Draws a scatter plot.

        This is very similar to `matplotlib.pyplot.scatter` but additionally
        accepts labels (for labeling points on the plot), plot title, and an
        axis where the plot should be drawn.

        :Args:
            - x (an iterable object)
                An x-coordinate of data
            - y (an iterable object)
                A y-coordinate of data

        :Kwargs:
            - ax (default: None)
                An axis to plot in.
            - labels (list of str, default: None)
                A list of labels for each plotted point
            - title (str, default: '')
                Plot title
            - ** kwargs
                Additional keyword arguments for `matplotlib.pyplot.scatter`.
                `marker` (default: 'o') and `color` (default: the first color
                from :meth:`get_colors`) may be overridden here.

        :Return:
            Current axis for further manipulation.

        """
        if ax is None:
            self.subplotno += 1
            ax = self.get_ax()
        # user-supplied marker/color win instead of raising a duplicate-keyword
        # TypeError
        kwargs.setdefault('marker', 'o')
        if 'color' not in kwargs:
            kwargs['color'] = self.get_colors()[0]
        ax.scatter(x, y, **kwargs)
        if labels is not None:
            for c, (pointx, pointy) in enumerate(zip(x, y)):
                ax.text(pointx, pointy, labels[c])
        ax.set_title(title)
        return ax

    def matrix_plot(self, matrix, ax=None, title='', **kwargs):
        """
        Plots a matrix.

        .. warning:: Not tested yet

        The mean of `matrix` (via :meth:`errorbars`) is drawn with `imshow`
        on a fixed 0-1 color scale, with a colorbar and `title` as an inner
        title. Tick labels come from `matrix.minor_axis` (x) and
        `matrix.major_axis` (y), each label joined with '|'.

        :Args:
            matrix
                NOTE(review): must expose `major_axis`/`minor_axis` (the old
                `pandas.Panel` interface) -- confirm what callers pass.

        :Kwargs:
            - ax (default: None)
                An axis to plot on. It needs a `cax` attribute (as
                `ImageGrid` axes have) for the colorbar.
            - title (str, default: '')
                Plot title
            - **kwargs
                Keyword arguments to pass to `matplotlib.pyplot.imshow`

        :Returns:
            The axis drawn on.
        """
        if ax is None:
            ax = plt.subplot(111)
        import matplotlib.colors
        # the class is `Normalize`; lowercase `normalize` does not exist
        norm = matplotlib.colors.Normalize(vmax=1, vmin=0)
        mean, _ = self.errorbars(matrix)
        im = ax.imshow(mean, norm=norm, interpolation='none', **kwargs)

        ax.cax.colorbar(im)

        t = self.add_inner_title(ax, title, loc=2)
        t.patch.set_ec("none")
        t.patch.set_alpha(0.8)
        xnames = ['|'.join(map(str,label)) for label in matrix.minor_axis]
        ax.set_xticks(range(len(xnames)))
        ax.set_xticklabels(xnames)
        # rotate long labels
        if max([len(n) for n in xnames]) > 20:
            ax.axis['bottom'].major_ticklabels.set_rotation(90)
        ynames = ['|'.join(map(str,label)) for label in matrix.major_axis]
        ax.set_yticks(range(len(ynames)))
        ax.set_yticklabels(ynames)
        return ax

    def add_inner_title(self, ax, title, loc=2, size=None, **kwargs):
        """
        Write `title` inside `ax` as frameless anchored text with a white halo.

        :Args:
            - ax: the axis to annotate
            - title (str): text to show

        :Kwargs:
            - loc (default: 2): location code passed to `AnchoredText`
            - size (dict, default: None)
                Font properties of the text; defaults to
                ``dict(size=rcParams['legend.fontsize'])``.
            - **kwargs: passed on to `AnchoredText`

        :Returns:
            The `AnchoredText` artist that was added to `ax`.
        """
        from matplotlib.offsetbox import AnchoredText
        from matplotlib.patheffects import withStroke

        if size is None:
            prop = dict(size=plt.rcParams['legend.fontsize'])
        else:
            prop = size
        anchored = AnchoredText(title, loc=loc, prop=prop,
                                pad=0., borderpad=0.5,
                                frameon=False, **kwargs)
        ax.add_artist(anchored)
        halo = withStroke(foreground="w", linewidth=3)
        anchored.txt._text.set_path_effects([halo])
        return anchored

    def errorbars(self, df, yerr_type='sem'):
        """
        Compute column means and error bar sizes for `df`.

        :Args:
            df (`pandas.DataFrame`)
                Presumably rows are items/samples and columns are conditions.

        :Kwargs:
            yerr_type ({'sem'}, default: 'sem')
                'sem' gives the standard error of the mean,
                ``df.std() / sqrt(len(df))`` (pandas' `std` uses ddof=1).
                'binomial' is reserved but not implemented.

        :Returns:
            (mean, sem)

        :Raises:
            NotImplementedError for 'binomial'; ValueError for any other
            unknown `yerr_type` (previously both hit an UnboundLocalError).
        """
        if yerr_type == 'sem':
            mean = df.mean()  # mean across items
            # std already has ddof=1
            sem = df.std() / np.sqrt(len(df))
            return mean, sem
        if yerr_type == 'binomial':
            # Intended computation (not implemented):
            #   alpha = .05
            #   z = stats.norm.ppf(1-alpha/2.)
            #   p_yerr = z*np.sqrt(mean*(1-mean)/persubj.shape[1])
            raise NotImplementedError('binomial error bars are not implemented')
        raise ValueError('unknown yerr_type: %r' % (yerr_type,))

    def stats_test(self, agg, test='ttest'):
        d = agg.shape[0]

        if test == 'ttest':
            # 2-tail T-Test
            ttest = (np.zeros((agg.shape[1]*(agg.shape[1]-1)/2, agg.shape[2])),
                     np.zeros((agg.shape[1]*(agg.shape[1]-1)/2, agg.shape[2])))
            ii = 0
            for c1 in range(agg.shape[1]):
                for c2 in range(c1+1,agg.shape[1]):
                    thisTtest = stats.ttest_rel(agg[:,c1,:], agg[:,c2,:], axis = 0)
                    ttest[0][ii,:] = thisTtest[0]
                    ttest[1][ii,:] = thisTtest[1]
                    ii += 1
            ttestPrint(title = '**** 2-tail T-Test of related samples ****',
                values = ttest, plotOpt = plotOpt,
                type = 2)

        elif test == 'ttest_1samp':
            # One-sample t-test
            m = .5
            oneSample = stats.ttest_1samp(agg, m, axis = 0)
            ttestPrint(title = '**** One-sample t-test: difference from %.2f ****' %m,
                values = oneSample, plotOpt = plotOpt, type = 1)

        elif test == 'binomial':
            # Binomial test
            binom = np.apply_along_axis(stats.binom_test,0,agg)
            print binom
            return binom


    def ttestPrint(self, title = '****', values = None, xticklabels = None, legend = None, bon = None, nsamples = 8):
        """
        Print t-test results as ``t(df) = ..., p = ...`` lines with stars.

        :Kwargs:
            - title (str): header line
            - values (pair of arrays)
                ``(t, p)``. 1-D arrays are indexed by x; 2-D arrays by
                (comparison, x).
            - xticklabels (list of str, default: None)
                One label per x; defaults to '0', '1', ...
            - legend (list of str, default: None)
                Condition names; rows are labeled with every pair
                ``'a vs b'`` (pairs as produced by `itertools.combinations`).
                Defaults to row numbers.
            - bon (number, default: None)
                Bonferroni divisor for the *, **, *** thresholds
                (.05, .01, .001); None means no correction.
            - nsamples (int, default: 8)
                Number of samples; degrees of freedom are printed as
                ``nsamples - 1``.
        """
        import itertools

        tvals = np.asarray(values[0])
        pvals = np.asarray(values[1])
        d = nsamples
        if bon is None:
            bon = 1  # no Bonferroni correction
        # check if there are any negative t values (for formatting purposes)
        neg = bool(np.any([np.any(val < 0) for val in values]))

        if xticklabels is None:
            xticklabels = [str(i) for i in range(tvals.shape[-1])]
        # row labels do not depend on x, so build them once
        if legend is None:
            nrows = tvals.shape[0] if tvals.ndim > 1 else 1
            legendnames = [str(i) for i in range(nrows)]
        else:
            maxleg = max([len(leg) for leg in legend] + [0])
            legendnames = ['%*s' %(maxleg,a) + ' vs ' + '%*s' %(maxleg,b)
                           for a, b in itertools.combinations(legend, 2)]

        print('\n' + title)
        for xi, xticklabel in enumerate(xticklabels):
            print(xticklabel)
            for yi, legendname in enumerate(legendnames):
                if tvals.ndim == 1:
                    t = tvals[xi]
                    p = pvals[xi]
                else:
                    t = tvals[yi, xi]
                    p = pvals[yi, xi]
                if p < .001/bon: star = '***'
                elif p < .01/bon: star = '**'
                elif p < .05/bon: star = '*'
                else: star = ''

                # extra space keeps positive t aligned with negative ones
                if neg and t > 0:
                    outputStr = '    %(s)s: t(%(d)d) =  %(t).3f, p = %(p).3f %(star)s'
                else:
                    outputStr = '    %(s)s: t(%(d)d) = %(t).3f, p = %(p).3f %(star)s'

                print(outputStr % {'s': legendname, 'd': (d-1), 't': t,
                                   'p': p, 'star': star})

    def mds(self, results, labels, fonts='freesansbold.ttf', title='',
        ax = None):
        """Plots Multidimensional scaling results.

        :Args:
            - results (`numpy.ndarray`)
                1-D (or a single column): drawn via :meth:`_plot` as values
                sorted ascending. 2-D: each row is an (x, y) point labeled
                with the matching entry of `labels`.
            - labels (list of str)

        :Kwargs:
            - fonts (str or list, default: 'freesansbold.ttf')
                `fontproperties` for the point labels: one value for all
                labels if a string, otherwise one per label.
            - title (str, default: '')
            - ax (default: None)
                Axis to draw on; if None, the current subplot is used.
        """
        if ax is None:
            try:
                # 2-D grid of old-style subplots exposing numCols
                ncols = self.axes[0][0].numCols
                row = self.subplotno // ncols
                col = self.subplotno % ncols
                ax = self.axes[row][col]
            except (AttributeError, TypeError, IndexError):
                ax = self.axes[self.subplotno]
        ax.set_title(title)
        # plot each point with a name
        dims = results.ndim
        try:
            if results.shape[1] == 1:
                dims = 1
        except IndexError:  # 1-D results have no second axis
            pass
        if dims == 1:
            df = pandas.DataFrame(results, index=labels, columns=['data'])
            if hasattr(df, 'sort_values'):
                df = df.sort_values(by='data')
            else:  # old pandas without sort_values
                df = df.sort(columns='data')
            self._plot(df)
        elif dims == 2:
            for c, coord in enumerate(results):
                # a plain string default would otherwise be indexed per letter
                font = fonts if isinstance(fonts, str) else fonts[c]
                ax.plot(coord[0], coord[1], 'o', color=mpl.cm.Paired(.5))
                ax.text(coord[0], coord[1], labels[c], fontproperties=font)
        else:
            print('Cannot plot more than 2 dims')


    def _violinplot(self, data, pos, rlabels, ax=None, bp=False, cut=None, **kwargs):
        """
        Make a violin plot of each dataset in the `data` sequence.

        For every row label in `rlabels`, a Gaussian KDE is fit to the first
        column of `data` (and to the second column, if there is one) and drawn
        as a violin centered at the matching entry of `pos`. With two columns
        the violin is split: first column on the left in black, second on the
        right in grey. With one column it is symmetric and black. A proxy
        legend labeled with `data.columns` is added.

        :Args:
            - data: NOTE(review): assumed to be a `pandas.DataFrame` with one
              or two columns whose rows are selected by the labels in
              `rlabels` (via `.ix`, which was removed in pandas 1.0) --
              confirm against callers.
            - pos: x positions of the violins, one per row label; None
              means [0, 1].
            - rlabels: row labels to draw, in order.

        :Kwargs:
            - ax: axis to draw on. It is used directly, so despite the None
              default it must be given.
            - bp: not used (the boxplot overlay below is commented out).
            - cut ((low, high), default: None)
                Fixed vertical extent of every violin. If None, each extent
                is grown from the data range until less than 0.001 of both
                KDEs' mass lies beyond it.
            - **kwargs: not used.

        :Returns:
            ax

        Based on `code by Teemu Ikonen
        <http://matplotlib.1069221.n5.nabble.com/Violin-and-bean-plots-tt27791.html>`_
        which was based on `code by Flavio Codeco Coelho
        <http://pyinsci.blogspot.com/2009/09/violin-plot-with-matplotlib.html>`)
        """
        def draw_density(p, low, high, k1, k2, ncols=2):
            # Draws one violin at x=p between low and high. `w` and `ax` are
            # looked up from the enclosing scope at call time; `w` is
            # assigned below, before the first call.
            m = low #lower bound of violin
            M = high #upper bound of violin
            x = np.linspace(m, M, 100) # support for violin
            v1 = k1.evaluate(x) # violin profile (density curve)
            v1 = w*v1/v1.max() # scaling the violin to the available space
            v2 = k2.evaluate(x) # violin profile (density curve)
            v2 = w*v2/v2.max() # scaling the violin to the available space

            if ncols == 2:
                # split violin: k1 to the left of p, k2 to the right
                ax.fill_betweenx(x, -v1 + p, p, facecolor='black', edgecolor='black')
                ax.fill_betweenx(x, p, p + v2, facecolor='grey', edgecolor='gray')
            else:
                # one black shape; with a single column k2 is k1, so it is symmetric
                ax.fill_betweenx(x, -v1 + p, p + v2, facecolor='black', edgecolor='black')


        if pos is None:
            pos = [0,1]
        # half-width of each violin: a fraction of the span of positions, capped
        dist = np.max(pos)-np.min(pos)
        w = min(0.15*max(dist,1.0),0.5) * .5

        #for major_xs in range(data.shape[1]):
        for num, rlabel in enumerate(rlabels):
            p = pos[num]
            d1 = data.ix[rlabel, 0]
            k1 = scipy.stats.gaussian_kde(d1) #calculates the kernel density
            if data.shape[1] == 1:
                d2 = d1
                k2 = k1
            else:
                d2 = data.ix[rlabel, 1]
                k2 = scipy.stats.gaussian_kde(d2) #calculates the kernel density
            cutoff = .001  # KDE mass allowed beyond each end of the violin
            if cut is None:
                upper = max(d1.max(),d2.max())
                lower = min(d1.min(),d2.min())
                # NOTE(review): for integer data under Python 2 this division
                # floors, and a zero step would never move `low`/`high` in
                # the loops below -- confirm data is float.
                stepsize = (upper - lower) / 100
                area_low1 = 1  # max cdf value
                area_low2 = 1  # max cdf value
                low = min(d1.min(), d2.min())
                # step the lower end down until both lower tails are <= cutoff
                while area_low1 > cutoff or area_low2 > cutoff:
                    area_low1 = k1.integrate_box_1d(-np.inf, low)
                    area_low2 = k2.integrate_box_1d(-np.inf, low)
                    low -= stepsize
                    #print area_low, low, '.'
                area_high1 = 1  # max cdf value
                area_high2 = 1  # max cdf value
                high = max(d1.max(), d2.max())
                # step the upper end up until both upper tails are <= cutoff
                while area_high1 > cutoff or area_high2 > cutoff:
                    area_high1 = k1.integrate_box_1d(high, np.inf)
                    area_high2 = k2.integrate_box_1d(high, np.inf)
                    high += stepsize
            else:
                low, high = cut

            draw_density(p, low, high, k1, k2, ncols=data.shape[1])


        # a work-around for generating a legend for the PolyCollection
        # from http://matplotlib.org/users/legend_guide.html#using-proxy-artist
        left = Rectangle((0, 0), 1, 1, fc="black", ec='black')
        right = Rectangle((0, 0), 1, 1, fc="gray", ec='gray')

        ax.legend((left, right), data.columns.tolist())
        #import pdb; pdb.set_trace()
        #ax.set_xlim(pos[0]-3*w, pos[-1]+3*w)
        #if bp:
            #ax.boxplot(data,notch=1,positions=pos,vert=1)
        return ax


    def _stripchart(self, data, pos, rlabels, ax=None, mean=False, median=False,
        width=None, discrete=True, bins=30):
        """Plot samples given in `data` as horizontal lines.

        :Kwargs:
            mean: plot mean of each dataset as a thicker line if True
            median: plot median of each dataset as a dot if True.
            width: Horizontal width of a single dataset plot.
        """
        def draw_lines(d, maxcount, hist, bin_edges, sides=None):
            if discrete:
                bin_edges = bin_edges[:-1]  # upper edges not needed
                hw = hist * w / (2.*maxcount)
            else:
                bin_edges = d
                hw = w / 2.

            ax.hlines(bin_edges, sides[0]*hw + p, sides[1]*hw + p, color='white')
            if mean:  # draws a longer black line
                ax.hlines(np.mean(d), sides[0]*2*w + p, sides[1]*2*w + p,
                    lw=2, color='black')
            if median:  # puts a white dot
                ax.plot(p, np.median(d), 'o', color='white', markeredgewidth=0)

        #data, pos = self._beanlike_setup(data, ax)

        if width:
            w = width
        else:
            #if pos is None:
                #pos = [0,1]
            dist = np.max(pos)-np.min(pos)
            w = min(0.15*max(dist,1.0),0.5) * .5

        #colnames = [d for d in data.columns.names if d.startswith('cols.') ]
        #if len(colnames) == 0:  # nothing specified explicitly as a columns
            #try:
                #colnames = data.columns.levels[-1]
            #except:
                #colnames = data.columns

        #func1d = lambda x: np.histogram(x, bins=bins)
        # apply along cols
        hist, bin_edges = np.apply_along_axis(np.histogram, 0, data, bins)
        # it return arrays of object type, so we got to correct that
        hist = np.array(hist.tolist())
        bin_edges = np.array(bin_edges.tolist())
        maxcount = np.max(hist)

        for n, rlabel in enumerate(rlabels):
            p = pos[n]
            d = data.ix[:,rlabel]
            if len(d.columns) == 1:
                draw_lines(d.ix[:,0], maxcount, hist[0], bin_edges[0], sides=[-1,1])
            else:
                draw_lines(d.ix[:,0], maxcount, hist[0], bin_edges[0], sides=[-1,0])
                draw_lines(d.ix[:,1], maxcount, hist[1], bin_edges[1], sides=[ 0,1])

        ax.set_xlim(min(pos)-3*w, max(pos)+3*w)
        #ax.set_xticks([-1]+pos+[1])
        ax.set_xticks(pos)
        #import pdb; pdb.set_trace()
        #ax.set_xticklabels(['-1']+np.array(data.major_axis).tolist()+['1'])
        if len(rlabels) > 1:
            ax.set_xticklabels(rlabels)
        else:
            ax.set_xticklabels('')

        return ax


    def beanplot(self, data, ax=None, pos=None, mean=True, median=True, cut=None,
        order=None, discrete=True, **kwargs):
        """Make a bean plot of each dataset in the `data` sequence.

        A bean plot overlays a strip chart of the samples (`_stripchart`)
        with a kernel-density outline (`_violinplot`) on the same axes.

        Reference: http://www.jstatsoft.org/v28/c01/paper

        :Kwargs:
            - ax (default: None)
                Axes to draw on. If None, the next axes is obtained once via
                ``self.next()`` and used for both layers.
            - pos (default: None)
                Currently ignored: positions are recomputed by
                `_beanlike_setup`.
            - mean (bool, default: True)
                Draw the mean of each dataset (forwarded to `_stripchart`).
            - median (bool, default: True)
                Draw the median of each dataset (forwarded to `_stripchart`).
            - cut (tuple, default: None)
                ``(low, high)`` limits of the density outline; forwarded to
                `_violinplot`.
            - order (default: None)
                Forwarded to `_beanlike_setup`.
            - discrete (bool, default: True)
                Forwarded to `_stripchart`.
            - **kwargs
                Accepted but unused.

        :Returns:
            The axes drawn on.
        """
        # Resolve the axes here: `_beanlike_setup` would call self.next()
        # itself but does not return the axes, which used to leave `ax` as
        # None for the plotting helpers below.
        if ax is None:
            ax = self.next()

        data_tr, pos, rlabels = self._beanlike_setup(data, ax, order)

        # Same width rule as `_stripchart`; the strips are drawn at 80% of it.
        dist = np.max(pos) - np.min(pos)
        w = min(0.15 * max(dist, 1.0), 0.5) * .5
        # NOTE(review): the strip chart receives the caller's original `data`,
        # while the violin layer receives the reshaped table -- confirm this
        # asymmetry is intended.
        ax = self._stripchart(data, pos, rlabels, ax=ax, mean=mean, median=median,
            width=0.8 * w, discrete=discrete)
        ax = self._violinplot(data_tr, pos, rlabels, ax=ax, bp=False, cut=cut)

        return ax

    def _beanlike_setup(self, data, ax, order=None):
        data = pandas.DataFrame(data)  # Series will be forced into a DataFrame
        data = data.unstack([n for n in data.index.names if n.startswith('yerr.')])
        data = data.unstack([n for n in data.index.names if n.startswith('rows.')])
        rlabels = data.columns
        data = data.unstack([n for n in data.index.names if n.startswith('yerr.')])
        data = data.T  # now rows and values are in rows, cols in cols

        #if len(data.columns) > 2:
            #raise Exception('Beanplot cannot handle more than two categories')
        if len(data.index.levels[-1]) <= 1:
            raise Exception('Cannot make a beanplot for a single observation')

        ## put columns at the bottom so that it's easy to iterate in violinplot
        #order = {'rows': [], 'cols': []}
        #for i,n in enumerate(data.columns.names):
            #if n.startswith('cols.'):
                #order['cols'].append(i)
            #else:
                #order['rows'].append(i)
        #data = data.reorder_levels(order['rows'] + order['cols'], axis=1)

        if ax is None:
            ax = self.next()
        #if order is None:
        pos = range(len(rlabels))
        #else:
            #pos = np.lexsort((np.array(data.index).tolist(), order))

        return data, pos, rlabels