# Example #1
0
def getFigAx(subplot, name=None, title=None, figsize=None,  mpl=None, margins=None,
             sharex=None, sharey=None, AxesGrid=False, ngrids=None, direction='row',
             axes_pad = None, add_all=True, share_all=None, aspect=False,
             label_mode='L', cbar_mode=None, cbar_location='right',
             cbar_pad=None, cbar_size='5%', axes_class=None, lreduce=True): 
  ''' Create a figure with a grid of axes (deprecated: use Figure or Axes class methods).

      subplot  : number of panels (1, 2, 3, 4, 6 or 9) or a (rows, columns) pair of ints
      name     : window title of the figure (optional)
      title    : printed figure title; also lowers the default top margin (optional)
      figsize  : figure size in inches; the default depends on the subplot layout
      mpl      : matplotlib module, or a dict of keyword arguments for loadMPL
      margins  : (left, bottom, width, height) of the axes area in figure coordinates
      AxesGrid : if True, lay out the axes with mpl_toolkits.axes_grid1.ImageGrid (the
                 ngrids ... axes_class keywords are passed on to it), otherwise use
                 pyplot.subplots with sharex/sharey/share_all
      lreduce  : return a bare axes instance instead of a sequence when possible

      Returns (fig, axes): a MyFigure instance and the axes (a tuple for ImageGrid,
      the output of pyplot.subplots otherwise).

      Raises TypeError for an invalid 'mpl' or 'subplot' argument and
      NotImplementedError for an unsupported number of panels.
  '''
  # configure matplotlib
  warn('Deprecated function: use Figure or Axes class methods.')
  if mpl is None: import matplotlib as mpl
  elif isinstance(mpl,dict): mpl = loadMPL(**mpl) # there can be a mplrc, but also others
  elif not isinstance(mpl,ModuleType): raise TypeError("'mpl' has to be None, a dict or a module.")
  from plotting.figure import MyFigure # prevent circular reference
  # figure out subplots
  default_layouts = {1:(1,1), 2:(1,2), 3:(1,3), 4:(2,2), 6:(2,3), 9:(3,3)}
  if isinstance(subplot,(np.integer,int)):
    if int(subplot) not in default_layouts:
      raise NotImplementedError("No default layout for {} panels.".format(subplot))
    subplot = default_layouts[int(subplot)]
  elif (isinstance(subplot,(tuple,list)) and len(subplot) == 2
        and all(isinstance(n,(np.integer,int)) for n in subplot)):
    # normalize lists to tuples, so that the layout comparisons below match
    subplot = tuple(subplot)
  else: raise TypeError("'subplot' has to be an int or a (rows, columns) pair of ints.")
  # create figure
  if figsize is None: 
    if subplot == (1,1): figsize = (3.75,3.75)
    elif subplot == (1,2) or subplot == (1,3): figsize = (6.25,3.75)
    elif subplot == (2,1) or subplot == (3,1): figsize = (3.75,6.25)
    else: figsize = (6.25,6.25)
  # figure out margins
  if margins is None:
    # N.B.: the rectangle definition is presumably left, bottom, width, height
    if subplot == (1,1): margins = (0.09,0.09,0.88,0.88)
    elif subplot == (1,2) or subplot == (1,3): margins = (0.06,0.1,0.92,0.87)
    elif subplot == (2,1) or subplot == (3,1): margins = (0.09,0.11,0.88,0.82)
    elif subplot == (2,2) or subplot == (3,3): margins = (0.055,0.055,0.925,0.925)
    else: margins = (0.09,0.11,0.88,0.82)
    if title is not None: margins = margins[:3]+(margins[3]-0.03,) # make room for title
  if AxesGrid:
    if share_all is None: share_all = True
    if axes_pad is None: axes_pad = 0.05
    # create axes using the Axes Grid package
    fig = mpl.pylab.figure(facecolor='white', figsize=figsize, FigureClass=MyFigure)
    if axes_class is None:
      from plotting.axes import MyLocatableAxes  
      axes_class=(MyLocatableAxes,{})
    from mpl_toolkits.axes_grid1 import ImageGrid
    # 'add_all' was deprecated and later removed from ImageGrid; True is its default,
    # so only forward the keyword when a non-default value is requested
    grid_kwargs = {} if add_all else dict(add_all=add_all)
    # AxesGrid: http://matplotlib.org/mpl_toolkits/axes_grid/users/overview.html
    grid = ImageGrid(fig, margins, nrows_ncols = subplot, ngrids=ngrids, direction=direction, 
                     axes_pad=axes_pad, share_all=share_all, aspect=aspect, 
                     label_mode=label_mode, cbar_mode=cbar_mode, cbar_location=cbar_location, 
                     cbar_pad=cbar_pad, cbar_size=cbar_size, axes_class=axes_class, **grid_kwargs)
    # return figure and axes
    axes = tuple(grid) # this is already flattened
    if lreduce and len(axes) == 1: axes = axes[0] # return a bare axes instance, if there is only one axes    
  else:
    # create axes using normal subplot routine
    if axes_pad is None: axes_pad = 0.03
    wspace = hspace = axes_pad
    if share_all: 
      sharex='all'; sharey='all'
    if sharex is True or sharex is None: sharex = 'col' # default
    if sharey is True or sharey is None: sharey = 'row'
    # shared axes carry no inner tick labels, so the panels can sit a bit closer
    if sharex: hspace -= 0.015
    if sharey: wspace -= 0.015
    # create figure
    from matplotlib.pyplot import subplots    
    # GridSpec: http://matplotlib.org/users/gridspec.html 
    fig, axes = subplots(subplot[0], subplot[1], sharex=sharex, sharey=sharey,
                         squeeze=lreduce, facecolor='white', figsize=figsize, FigureClass=MyFigure)    
    # margins are (left, bottom, width, height); subplots_adjust wants edges
    margin_dict = dict(left=margins[0], bottom=margins[1], right=margins[0]+margins[2], 
                       top=margins[1]+margins[3], wspace=wspace, hspace=hspace)
    fig.subplots_adjust(**margin_dict)
  # add figure title
  if name is not None: # window title
    # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6; use the manager
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None: manager.set_window_title(name)
    elif hasattr(fig.canvas, 'set_window_title'): fig.canvas.set_window_title(name)
  if title is not None: fig.suptitle(title) # title on figure (printable)
  # return Figure and axes
  return fig, axes
# Example #2
0
def draw_time_series(results,
                     times,
                     labels,
                     fname,
                     fmt='png',
                     gridshape=(1, 1),
                     xlabel='',
                     ylabel='',
                     ptitle='',
                     subtitles=None,
                     label_month=False,
                     yscale='linear',
                     aspect=None):
    '''
    Purpose::
        Function to draw a time series plot

    Input::
        results - a 3d array of time series (a 2d array is treated as a single
                  subplot)
        times - a list of python datetime objects
        labels - a list of strings with the names of each set of data
        fname - a string specifying the filename of the plot
        fmt - an optional string specifying the output filetype
        gridshape - optional tuple denoting the desired grid shape (nrows, ncols) for arranging
                    the subplots.
        xlabel - a string specifying the x-axis title
        ylabel - a string specifying the y-axis title
        ptitle - a string specifying the plot title
        subtitles - an optional list of strings specifying the title for each subplot
        label_month - optional bool to toggle drawing month labels
        yscale - optional string for setting the y-axis scale, 'linear' for linear
                 and 'log' for log base 10.
        aspect - Float denoting approximate aspect ratio of each subplot
                 (width / height). NOTE: currently not used by this function.

    Output::
        Writes the figure to '<fname>.<fmt>' and returns None.
    '''
    # Handle the single plot case.
    if results.ndim == 2:
        results = results.reshape(1, *results.shape)

    # Make sure gridshape is compatible with input data
    nplots = results.shape[0]
    gridshape = _best_grid_shape(nplots, gridshape)

    # Set up the figure
    width, height = _fig_size(gridshape)
    fig = plt.figure()
    fig.set_size_inches((width, height))
    fig.dpi = 300

    # Make the subplot grid. 'add_all' is not passed: True is ImageGrid's
    # default, and newer Matplotlib versions no longer accept the keyword.
    grid = ImageGrid(fig,
                     111,
                     nrows_ncols=gridshape,
                     axes_pad=0.3,
                     share_all=True,
                     ngrids=nplots,
                     label_mode='L',
                     aspect=False,
                     cbar_mode='single',
                     cbar_location='bottom',
                     cbar_size=.05,
                     cbar_pad=.20)

    # Make the plots
    for i, ax in enumerate(grid):
        data = results[i]
        if label_month:
            xfmt = mpl.dates.DateFormatter('%b')
            xloc = mpl.dates.MonthLocator()
            ax.xaxis.set_major_formatter(xfmt)
            ax.xaxis.set_major_locator(xloc)

        # Set the y-axis scale
        ax.set_yscale(yscale)

        # Set up list of lines for legend
        lines = []
        # The range is seeded with 0, so the y-axis always includes zero.
        ymin, ymax = 0, 0

        # Plot each line
        for tSeries in data:
            line = ax.plot_date(times, tSeries, '')
            lines.extend(line)
            ymin = min(ymin, tSeries.min())
            ymax = max(ymax, tSeries.max())

        # Add 10% of the data range on both sides so lines don't touch the
        # bottom and top of the plot. The range is computed once so that the
        # upper padding is not inflated by the already-lowered lower limit.
        pad = (ymax - ymin) * 0.1
        ax.set_ylim((ymin - pad, ymax + pad))

        # Set the subplot title if desired
        if subtitles is not None:
            ax.set_title(subtitles[i], fontsize='small')

    # Create a master axes rectangle for figure wide labels. Tick switches
    # must be booleans: the legacy 'off' strings are truthy in current
    # Matplotlib and would turn the ticks on.
    fax = fig.add_subplot(111, frameon=False)
    fax.tick_params(labelcolor='none',
                    top=False,
                    bottom=False,
                    left=False,
                    right=False)
    fax.set_ylabel(ylabel)
    fax.set_title(ptitle, fontsize=16)
    fax.title.set_y(1.04)

    # Create the legend using a 'fake' colorbar axes. This lets us have a nice
    # legend that is in sync with the subplot grid. The handles are those of
    # the last subplot drawn.
    cax = ax.cax
    cax.set_frame_on(False)
    cax.set_xticks([])
    cax.set_yticks([])
    cax.legend(lines,
               labels,
               loc='upper center',
               ncol=10,
               fontsize='small',
               mode='expand',
               frameon=False)

    # Note that due to weird behavior by axes_grid, it is more convenient to
    # place the x-axis label relative to the colorbar axes instead of the
    # master axes rectangle.
    cax.set_title(xlabel, fontsize=12)
    cax.title.set_y(-1.5)

    # Rotate the x-axis tick labels
    for ax in grid:
        for xtick in ax.get_xticklabels():
            xtick.set_ha('right')
            xtick.set_rotation(30)

    # Save the figure, then release it: clf() alone leaves the figure
    # registered with pyplot.
    fig.savefig('%s.%s' % (fname, fmt), bbox_inches='tight', dpi=fig.dpi)
    fig.clf()
    plt.close(fig)
# Example #3
0
def draw_portrait_diagram(results,
                          rowlabels,
                          collabels,
                          fname,
                          fmt='png',
                          gridshape=(1, 1),
                          xlabel='',
                          ylabel='',
                          clabel='',
                          ptitle='',
                          subtitles=None,
                          cmap=None,
                          clevs=None,
                          nlevs=10,
                          extend='neither',
                          aspect=None):
    '''
    Purpose::
        Makes a portrait diagram plot.

    Input::
        results - 3d array of the field to be plotted. The second dimension
                  should correspond to the number of rows in the diagram and the
                  third should correspond to the number of columns. A 2d array
                  is treated as a single subplot.
        rowlabels - a list of strings denoting labels for each row
        collabels - a list of strings denoting labels for each column
        fname - a string specifying the filename of the plot
        fmt - an optional string specifying the output filetype
        gridshape - optional tuple denoting the desired grid shape (nrows, ncols) for arranging
                    the subplots.
        xlabel - an optional string specifying the x-axis title
        ylabel - an optional string specifying the y-axis title
        clabel - an optional string specifying the colorbar title
        ptitle - a string specifying the plot title
        subtitles - an optional list of strings specifying the title for each subplot
        cmap - an optional string or matplotlib.colors.LinearSegmentedColormap instance
               denoting the colormap
        clevs - an optional list of ints or floats specifying colorbar levels
        nlevs - an optional integer specifying the target number of contour levels if
                clevs is None
        extend - an optional string to toggle whether to place arrows at the colorbar
             boundaries. Default is 'neither', but can also be 'min', 'max', or
             'both'. Will be automatically set to 'both' if clevs is None.
        aspect - Float denoting approximate aspect ratio of each subplot
                 (width / height). NOTE: currently not used by this function.

    Output::
        Writes the figure to '<fname>.<fmt>' and returns None.

    Raises::
        ValueError if rowlabels/collabels do not match the data shape.
    '''
    # Handle the single plot case.
    if results.ndim == 2:
        results = results.reshape(1, *results.shape)

    nplots = results.shape[0]

    # Make sure gridshape is compatible with input data
    gridshape = _best_grid_shape(nplots, gridshape)

    # Row and Column labels must be consistent with the shape of
    # the input data too
    prows, pcols = results.shape[1:]
    if len(rowlabels) != prows or len(collabels) != pcols:
        raise ValueError(
            'rowlabels and collabels must have %d and %d elements respectively'
            % (prows, pcols))

    # Set up the figure
    width, height = _fig_size(gridshape)
    fig = plt.figure()
    fig.set_size_inches((width, height))
    fig.dpi = 300

    # Make the subplot grid. 'add_all' is not passed: True is ImageGrid's
    # default, and newer Matplotlib versions no longer accept the keyword.
    grid = ImageGrid(fig,
                     111,
                     nrows_ncols=gridshape,
                     axes_pad=0.4,
                     share_all=True,
                     aspect=False,
                     ngrids=nplots,
                     label_mode='all',
                     cbar_mode='single',
                     cbar_location='bottom',
                     cbar_size=.15,
                     cbar_pad='3%')

    # Calculate colorbar levels if not given
    if clevs is None:
        # Cut off the tails of the distribution
        # for more representative colorbar levels
        clevs = _nice_intervals(results, nlevs)
        extend = 'both'

    cmap = plt.get_cmap(cmap)
    norm = mpl.colors.BoundaryNorm(clevs, cmap.N)

    # Do the plotting
    for i, ax in enumerate(grid):
        data = results[i]
        cs = ax.matshow(data,
                        cmap=cmap,
                        aspect='auto',
                        origin='lower',
                        norm=norm)

        # Add grid lines: cells are centred on integer coordinates, so the
        # cell edges lie at -0.5, 0.5, ..., n - 0.5.
        nrows, ncols = data.shape
        x = np.arange(ncols + 1) - .5
        y = np.arange(nrows + 1) - .5
        ax.vlines(x, y.min(), y.max())
        ax.hlines(y, x.min(), x.max())

        # One tick per cell at its centre. Exactly one tick per label is
        # required: recent Matplotlib raises ValueError when the number of
        # fixed ticks and labels differ.
        ax.xaxis.set_ticks(np.arange(ncols))
        ax.yaxis.set_ticks(np.arange(nrows))
        ax.set_xlim(x.min(), x.max())
        ax.set_ylim(y.min(), y.max())

        # Configure ticks
        ax.xaxis.tick_bottom()
        ax.xaxis.set_ticks_position('none')
        ax.yaxis.set_ticks_position('none')
        ax.set_xticklabels(collabels, fontsize='xx-small')
        ax.set_yticklabels(rowlabels, fontsize='xx-small')

        # Add axes title
        if subtitles is not None:
            ax.text(0.5,
                    1.04,
                    subtitles[i],
                    va='center',
                    ha='center',
                    transform=ax.transAxes,
                    fontsize='small')

    # Create a master axes rectangle for figure wide labels. Tick switches
    # must be booleans: the legacy 'off' strings are truthy in current
    # Matplotlib and would turn the ticks on.
    fax = fig.add_subplot(111, frameon=False)
    fax.tick_params(labelcolor='none',
                    top=False,
                    bottom=False,
                    left=False,
                    right=False)
    fax.set_ylabel(ylabel)
    fax.set_title(ptitle, fontsize=16)
    fax.title.set_y(1.04)

    # Add colorbar (the grid has a single shared colorbar axes)
    cax = ax.cax
    cbar = fig.colorbar(cs,
                        cax=cax,
                        norm=norm,
                        boundaries=clevs,
                        drawedges=True,
                        extend=extend,
                        orientation='horizontal',
                        extendfrac='auto')
    cbar.set_label(clabel)
    cbar.set_ticks(clevs)
    cbar.ax.xaxis.set_ticks_position('none')
    cbar.ax.yaxis.set_ticks_position('none')

    # Note that due to weird behavior by axes_grid, it is more convenient to
    # place the x-axis label relative to the colorbar axes instead of the
    # master axes rectangle.
    cax.set_title(xlabel, fontsize=12)
    cax.title.set_y(1.5)

    # Save the figure, then release it: clf() alone leaves the figure
    # registered with pyplot.
    fig.savefig('%s.%s' % (fname, fmt), bbox_inches='tight', dpi=fig.dpi)
    fig.clf()
    plt.close(fig)
# Example #4
0
def evaluate_during_training(model, test_sample, ode_integration, time_history,
                             progress_ep, progress_bat):
    """Evaluate the model on a test batch and estimate the remaining training time.

    Relies on the module-level globals ``batch_size``, ``frames``, ``epochs``
    and ``batches`` and on the helper ``compute_loss``.

    Args:
        model: model exposing ``encode``, ``reparameterize``,
            ``latent_trajectory`` and ``decode``.
        test_sample: test batch; presumably (batch, H, W, frames) — TODO confirm.
        ode_integration: passed through to ``latent_trajectory``/``compute_loss``.
        time_history: 2-D array of timestamps indexed [epoch, batch].
        progress_ep: current epoch index.
        progress_bat: current batch index within the epoch.

    Returns:
        For the first batch: ``(elbo, rec, mse, '-', '--')``.
        For intermediate batches: ``(elbo, rec, mse, minutes_left, seconds_left)``
        for the current epoch.
        For the last batch (only reached when ``batches > 10``): shows a figure
        of originals/reconstructions and returns ``(minutes_left, seconds_left,
        minutes_elapsed, seconds_elapsed)`` for the whole training run.
    """
    x_encoded = model.encode(test_sample, 0)
    z_0 = model.reparameterize(x_encoded)
    latent_states_ode, _ = model.latent_trajectory(z_0, tf.zeros(batch_size),
                                                   ode_integration)
    x_rec = np.array(
        tf.stack(
            [model.decode(latent_states_ode[i, :, 0]) for i in range(frames)],
            axis=3))

    # KL annealing weight: fraction of training completed so far.
    gamma = progress_ep / epochs + progress_bat / (batches * epochs)

    # A Keras Mean metric accumulates across calls, so each quantity gets its
    # own instance; reusing one made `rec` the average of loss and rec term.
    elbo_mean = tf.keras.metrics.Mean()
    elbo_mean(compute_loss(model, test_sample, ode_integration, gamma))
    elbo = -elbo_mean.result() / batch_size
    rec_mean = tf.keras.metrics.Mean()
    rec_mean(
        tf.reduce_sum(frames *
                      tf.keras.losses.binary_crossentropy(test_sample, x_rec)))
    rec = -rec_mean.result() / batch_size

    meansqerr = tf.keras.metrics.MeanSquaredError()
    meansqerr(test_sample, x_rec)
    mse = meansqerr.result()

    # NOTE(review): the `< 10` branch is tested before the last-batch branch,
    # so with batches <= 10 the final summary branch is never reached.
    if progress_bat == 0:
        return int(elbo), int(rec), mse, '-', '--'
    elif progress_bat < 10:
        # Average time per batch since the start of the epoch.
        a = time_history[progress_ep, progress_bat]
        b = time_history[progress_ep, 0]
        mins, secs = divmod(
            (a - b) / progress_bat * (batches - progress_bat - 1), 60)
        return int(elbo), int(rec), mse, int(mins), int(secs)
    elif progress_bat != batches - 1:
        # Average time per batch over the last 10 batches.
        a = time_history[progress_ep, progress_bat]
        b = time_history[progress_ep, progress_bat - 10]
        mins, secs = divmod((a - b) / 10 * (batches - progress_bat - 1), 60)
        return int(elbo), int(rec), mse, int(mins), int(secs)
    else:
        n_show = 5  # number of random test samples to display
        fig = plt.figure(figsize=(8, 8))
        index = np.random.randint(batch_size, size=n_show)
        grid = ImageGrid(
            fig,
            111,
            nrows_ncols=(2 * n_show, frames),
            axes_pad=0.1,
        )
        # One row of original frames followed by one row of reconstructions
        # per sample; the image size is taken from the reconstruction.
        plot = np.zeros((2 * n_show * frames,) + tuple(x_rec.shape[1:3]))
        for j, i in enumerate(index):
            plot[j * 2 * frames:(j * 2 + 1) * frames] = np.transpose(
                np.array(test_sample[i]), (2, 0, 1))
            plot[(j * 2 + 1) * frames:(j * 2 + 2) * frames] = np.transpose(
                x_rec[i], (2, 0, 1))
        plt.gray()
        for ax, im in zip(grid, plot):
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            ax.imshow(im)
        plt.show()

        # Extrapolate the duration of this epoch to the remaining epochs.
        a = time_history[progress_ep, progress_bat - 1]
        b = time_history[progress_ep, 0]
        mins, secs = divmod((a - b) * (epochs - progress_ep - 1), 60)
        mins_el, secs_el = divmod((a - b), 60)
        return int(mins), int(secs), int(mins_el), int(secs_el)
# Example #5
0
def plotImageGrid(images, nrows_ncols=None, extent=None, clim=None, interpolation='none',
                  cmap='gray', imScale=2., cbar=True, titles=None, titlecol=('r', 'y'),
                  same_zscale=False, **kwds):
    """Plot a list of images on an ImageGrid, one panel per image.

    Each entry of ``images`` may be a plain 2-D array or an afw-like object:
    objects exposing ``computeImage``/``computeKernelImage``, ``getImage``,
    ``getMaskedImage`` or ``getArray`` are reduced to an array first.

    Args:
        images: sequence of images (arrays or image-like objects, see above).
        nrows_ncols: (rows, columns) of the grid; by default a near-square
            layout large enough for all images. Each entry is clamped to >= 1.
        extent: imshow extent; if None it is taken per image from its bbox.
        clim: (low, high) colour limits; if None, computed by zscale_image
            (jointly over all images when ``same_zscale`` is True).
        interpolation, cmap: passed to imshow.
        imScale: figure size per panel in inches.
        cbar: draw a colorbar (and clip the data to ``clim`` when it is given).
        titles: optional per-image titles drawn inside each panel.
        titlecol: (text colour, outline colour) for the titles.
        **kwds: forwarded to imshow.

    Returns:
        The ImageGrid instance. Note: switches the global style to 'ggplot'.
    """
    import matplotlib.pyplot as plt
    import matplotlib
    matplotlib.style.use('ggplot')
    from mpl_toolkits.axes_grid1 import ImageGrid
    from matplotlib.offsetbox import AnchoredText
    from matplotlib.patheffects import withStroke

    def add_inner_title(ax, title, loc, size=None, **kwargs):
        # Draw `title` inside `ax` with a contrasting outline for legibility.
        if size is None:
            size = dict(size=plt.rcParams['legend.fontsize'], color=titlecol[0])
        at = AnchoredText(title, loc=loc, prop=size,
                          pad=0., borderpad=0.5,
                          frameon=False, **kwargs)
        ax.add_artist(at)
        at.txt._text.set_path_effects([withStroke(foreground=titlecol[1], linewidth=3)])
        return at

    if nrows_ncols is None:
        # Near-square layout; guard against an empty image list (nrows == 0).
        nimages = len(images)
        nrows = int(np.floor(np.sqrt(nimages)))
        ncols = int(np.ceil(float(nimages) / nrows)) if nrows > 0 else 0
        nrows_ncols = (nrows, ncols)
    # Clamp to at least one row/column; build a new tuple instead of assigning
    # into nrows_ncols, which fails for tuples and would mutate a caller's list.
    nrows_ncols = (max(int(nrows_ncols[0]), 1), max(int(nrows_ncols[1]), 1))
    size = (nrows_ncols[1]*imScale, nrows_ncols[0]*imScale)
    fig = plt.figure(1, size)
    igrid = ImageGrid(fig, 111,  # similar to subplot(111)
                      nrows_ncols=nrows_ncols, direction='row',
                      axes_pad=0.1,  # pad between axes in inch.
                      label_mode="L",  # share_all=True,
                      cbar_location="right", cbar_mode="single", cbar_size='3%')
    extentWasNone = False
    clim_orig = clim

    imagesToPlot = []

    for ii in images:
        if hasattr(ii, 'computeImage'):
            if hasattr(ii, 'getDimensions'):
                img = afwImage.ImageD(ii.getDimensions())
                ii.computeImage(img, doNormalize=False)
                ii = img
            else:
                ii = ii.computeImage()
        if hasattr(ii, 'computeKernelImage'):
            if hasattr(ii, 'getDimensions'):
                img = afwImage.ImageD(ii.getDimensions())
                ii.computeKernelImage(img, doNormalize=False)
                ii = img
            else:
                # Objects reaching this branch lack computeImage (otherwise the
                # branch above would already have converted them), so the
                # kernel-image method is the one to call here.
                ii = ii.computeKernelImage()
        if hasattr(ii, 'getImage'):
            ii = ii.getImage()
        if hasattr(ii, 'getMaskedImage'):
            ii = ii.getMaskedImage().getImage()
        if hasattr(ii, 'getArray'):
            bbox = ii.getBBox()
            if extent is None:
                extentWasNone = True
                extent = (bbox.getBeginX(), bbox.getEndX(), bbox.getBeginY(), bbox.getEndY())
            ii = ii.getArray()

        imagesToPlot.append(ii)

    if clim_orig is None and same_zscale:
        # One shared colour scale computed over the pixels of all images.
        tmp_im = np.concatenate([iii.flatten() for iii in imagesToPlot])
        clim = clim_orig = zscale_image(tmp_im)
        del tmp_im

    for i, ii in enumerate(imagesToPlot):
        if clim_orig is None:
            clim = zscale_image(ii)
        if cbar and clim_orig is not None:
            ii = np.clip(ii, clim[0], clim[1])
        # Widen degenerate colour limits (e.g. a constant image); the second
        # check catches the case clim[0] == 0 where the first one is a no-op.
        if np.isclose(clim[0], clim[1]):
            clim = (clim[0], clim[1] + clim[0] / 10.)
        if np.isclose(clim[0], clim[1]):
            clim = (clim[0] - 0.1, clim[1] + 0.1)
        im = igrid[i].imshow(ii, origin='lower', interpolation=interpolation, cmap=cmap,
                             extent=extent, clim=clim, **kwds)
        if cbar:
            igrid[i].cax.colorbar(im)
        if titles is not None:  # assume titles is an array or tuple of same length as images.
            t = add_inner_title(igrid[i], titles[i], loc=2)
            t.patch.set_ec("none")
            t.patch.set_alpha(0.5)
        # NOTE(review): extentWasNone is set by the first bbox-derived extent
        # and is reset here, so later images reuse that first extent.
        if extentWasNone:
            extent = None
        extentWasNone = False
    return igrid
# Example #6
0
def _moment_colorbar(grid, cmap, vrange, label):
    """Add a colorbar for ``cmap`` spanning ``vrange`` to the first colorbar axes of ``grid``."""
    norm = mpl.colors.Normalize(vmin=vrange[0], vmax=vrange[1])
    cmapo = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    cmapo.set_array([])
    cbr = plt.colorbar(cmapo, cax=grid.cbar_axes[0])
    cbr.ax.set_ylabel(label, labelpad=-1)


def _plot_moment_pair(dmom, mmom, cmap, label, pdfname, figkwargs):
    """Plot data (left) and model (right) moment maps side by side and save to ``pdfname``.

    Both panels share a colour range taken from the NaN-ignoring min/max of
    the data map; ``figkwargs`` is forwarded to ``standardfig``.
    """
    # A 1x2 grid is wider than tall, so the canvas is (width, height) = (8, 5)
    # inches for every data/model pair.
    fig = plt.figure(1, (8., 5.))
    grid = ImageGrid(fig,
                     111,
                     nrows_ncols=(1, 2),
                     axes_pad=0.,
                     cbar_mode='single',
                     cbar_location='right')

    vrange = [np.nanmin(dmom.data), np.nanmax(dmom.data)]
    standardfig(raster=dmom,
                ax=grid[0],
                fig=fig,
                cmap=cmap,
                vrange=vrange,
                text='Data',
                textprop=[dict(size=12)],
                **figkwargs)
    standardfig(raster=mmom,
                ax=grid[1],
                fig=fig,
                cmap=cmap,
                vrange=vrange,
                beam=False,
                text='Model',
                textprop=[dict(size=12)],
                **figkwargs)

    _moment_colorbar(grid, cmap, vrange, label)
    plt.savefig(pdfname, format='pdf', dpi=300)
    plt.close()


def diagnostic_plots(model,
                     chainfile,
                     burnin=0.3,
                     channelmaps=True,
                     cornerplot=True,
                     momentmaps=True,
                     bestarray=False,
                     outname='model',
                     vrangecm=None,
                     vrangemom=None,
                     cmap0='RdYlBu_r',
                     cmap1='Spectral_r',
                     cmap2='copper_r',
                     maskvalue=3.0,
                     **kwargs):
    """ This function will create several diagnostic plots of a given model
    and MCMC chain. The plots are not paper-ready but are meant to understand
    how well the MCMC chain ran. Currently the following plots are generated:

    Channel maps of the data, model and residual as well as a composite file
    which shows the data with the residual overplotted as contours
    A corner plot of the given MCMC chain converted to the correct units
    (uses the corner package)
    Moment-0, 1 and 2 images of the data and the model. For the moment 0, the
    residual contours are also plotted in the moment0-data.

    input and keywords:
    model:       This is a Qube object that contains a defined model.
    chainfile:   This is a file containing an MCMC chain as a numpy.save
                 object. The shape of the object is (nwalkers, nruns, dim+1)
                 where dim is the number of variable parameters in the MCMC
                 chain and the extra slice contains the lnprobability of the
                 given parameters in that link.
    burnin:      burnin fraction of the chain (default = 30%, i.e., 0.3)
    channelmaps: If true will plot the channels maps of the data, model and
                 residual
    cornerplot:  If true will plot the corner plot of the chain file
    momentmaps:  If true will plot the zeroth, first and second moment of the
                 data and the model
    bestarray:   If true the plots will be generated from the model with the
                 highest probability if false the median parameters will be
                 chosen.
    outname:     The name of the root used to save the pdf figures
    vrangecm:    z-axis range used to plot the channelmaps
    vrangemom:   z-axis range used to plot the moment-0 channel maps
    cmap0:       colormap to use for the plotting of the moment-0 maps
    cmap1:       colormap to use for the plotting of the moment-1 maps
    cmap2:       colormap to use for the plotting of the moment-2 maps
    maskvalue:   value (in sigma) to use to generate the mask in the
                 moment-0 image which is used in higher moment images.
                 default is 3.0.
    **kwargs:    forwarded to create_channelmap and standardfig.

    Output: PDF files named '<outname>_*.pdf'; returns None.
    """

    # read in the chain data file, drop the burn-in and the lnprobability
    # column, and flatten walkers and links into one sample axis
    Chain = np.load(chainfile)
    Chain = Chain[:, int(burnin * Chain.shape[1]):, :-1]
    Chain = Chain.reshape((-1, Chain.shape[2]))
    # load the chain in the model then regenerate the model with the wanted
    # values (either median or best)
    model.get_chainresults(chainfile, burnin=burnin)
    if not bestarray:
        model.update_parameters(model.chainpar['MedianArray'])
    else:
        model.update_parameters(model.chainpar['BestArray'])
    model.create_model()

    # create the data, model and residual cubes
    dqube = dc(model)
    mqube = dc(model)
    mqube.data = model.model
    rqube = dc(model)
    rqube.data = model.data - model.model

    # make the corner plot
    if cornerplot is True:
        # convert each parameter to a physically meaningful quantity
        for idx, key in enumerate(model.mcmcmap):

            # get conversion factors for each key
            if model.initpar[key]['Conversion'] is not None:
                conversion = model.initpar[key]['Conversion'].value
            else:
                conversion = 1.
            # quick fix for IO
            if key == 'I0':
                Chain[:, idx] = Chain[:, idx] * 1E3

            Chain[:, idx] = Chain[:, idx] * conversion
        corner.corner(Chain,
                      labels=model.mcmcmap,
                      quantiles=[0.16, 0.5, 0.84],
                      show_titles=True)

        plt.savefig(outname + '_cornerplot.pdf', format='pdf', dpi=300)
        plt.close()

    # make the channel maps
    if channelmaps is True:
        # define some global properties for all plots
        if vrangecm is None:
            vrangecm = [np.nanmin(dqube.data), np.nanmax(dqube.data)]
        # per-channel contour levels at -3, 3, 6, ..., 27 sigma
        sigma = np.sqrt(model.variance[:, 0, 0])
        clevels = [np.insert(np.arange(3, 30, 3), 0, -3) * i for i in sigma]

        # (raster, contour, file suffix) for each channel map
        channelmap_specs = [(dqube, dqube, '_datachannelmap.pdf'),
                            (mqube, mqube, '_modelchannelmap.pdf'),
                            (rqube, rqube, '_residualchannelmap.pdf'),
                            (dqube, rqube, '_combinedchannelmap.pdf')]
        for raster, contour, suffix in channelmap_specs:
            create_channelmap(raster=raster,
                              contour=contour,
                              clevels=clevels,
                              pdfname=outname + suffix,
                              vrange=vrangecm,
                              **kwargs)

        plt.close()

    # make the moment maps
    if momentmaps is True:
        # create the moment-0 images
        dMom0 = dqube.calculate_moment(moment=0)
        mMom0 = mqube.calculate_moment(moment=0)
        rMom0 = rqube.calculate_moment(moment=0)

        # calculate the Mom0sig of the data and create the contour levels
        Mom0sig = (np.sqrt(np.nansum(model.variance[:, 0, 0])) *
                   model.__get_velocitywidth__())
        clevels = np.insert(np.arange(3, 30, 3), 0, -3) * Mom0sig
        if vrangemom is None:
            vrangemom = [-3 * Mom0sig, 11 * Mom0sig]
        # mask used to restrict the higher-moment maps
        mask = dMom0.mask_region(value=Mom0sig * maskvalue, applymask=False)

        # create the figure
        fig = plt.figure(1, (8., 8.))
        grid = ImageGrid(fig,
                         111,
                         nrows_ncols=(2, 2),
                         axes_pad=0.,
                         cbar_mode='single',
                         cbar_location='right')

        # plot the figures
        standardfig(raster=dMom0,
                    contour=dMom0,
                    clevels=clevels,
                    ax=grid[0],
                    fig=fig,
                    vrange=vrangemom,
                    cmap=cmap0,
                    text='Data',
                    textprop=[dict(size=12)],
                    **kwargs)
        standardfig(raster=mMom0,
                    contour=mMom0,
                    clevels=clevels,
                    ax=grid[1],
                    fig=fig,
                    vrange=vrangemom,
                    cmap=cmap0,
                    beam=False,
                    text='Model',
                    textprop=[dict(size=12)],
                    **kwargs)
        standardfig(raster=rMom0,
                    contour=rMom0,
                    clevels=clevels,
                    ax=grid[2],
                    fig=fig,
                    vrange=vrangemom,
                    cmap=cmap0,
                    text='Residual',
                    textprop=[dict(size=12)],
                    **kwargs)
        standardfig(raster=dMom0,
                    contour=rMom0,
                    clevels=clevels,
                    ax=grid[3],
                    fig=fig,
                    vrange=vrangemom,
                    cmap=cmap0,
                    text='Data with residual contours',
                    textprop=[dict(size=12)],
                    **kwargs)

        # plot the color bar
        _moment_colorbar(grid, cmap0, vrangemom, 'Moment-0')
        plt.savefig(outname + '_moment0.pdf', format='pdf', dpi=300)
        plt.close()

        # create the moment-1 images
        dsqube = dqube.mask_region(mask=mask)
        dMom1 = dsqube.calculate_moment(moment=1)
        msqube = mqube.mask_region(mask=mask)
        mMom1 = msqube.calculate_moment(moment=1)
        _plot_moment_pair(dMom1, mMom1, cmap1, 'Moment-1',
                          outname + '_moment1.pdf', kwargs)

        # create the moment-2 images (previously drawn on a swapped 5x8 inch
        # canvas; now the same 8x5 inch canvas as the moment-1 pair)
        dsqube = dqube.mask_region(mask=mask)
        dMom2 = dsqube.calculate_moment(moment=2)
        msqube = mqube.mask_region(mask=mask)
        mMom2 = msqube.calculate_moment(moment=2)
        _plot_moment_pair(dMom2, mMom2, cmap2, 'Moment-2',
                          outname + '_moment2.pdf', kwargs)
# Example #7
0
def doplot(prefix='/media/scratch/peri/does_matter/brownian-motion',
           snrs=(20, 50, 200, 500)):
    """Plot position CRB and measured error versus exposure time (Brownian motion).

    The left column shows a slice of a freshly generated diffusion image
    ("Reference") and its difference from the model image of the state `s`
    ("Difference"). The right panel shows, for every SNR in `snrs`, the CRB
    (solid line) and the measured error (markers joined by dashed lines)
    loaded from ``<prefix>-snr-<snr>.pkl``, against exposure time in units
    of a^2/D.

    Parameters
    ----------
    prefix : str
        Path prefix of the pickled result files. Each file holds the tuple
        ``(crb, val, err, pos, times)``.
    snrs : sequence of int
        SNR values to plot; one pickle file is read per value. At most four
        values are supported (one per marker symbol).

    Raises
    ------
    ValueError
        If `snrs` is empty, since the x-limits are taken from the loaded times.
    """
    if not snrs:
        raise ValueError("snrs must contain at least one SNR value")

    fig = pl.figure(figsize=(14, 7))

    ax = fig.add_axes([0.43, 0.15, 0.52, 0.75])
    gs = ImageGrid(fig,
                   rect=[0.05, 0.05, 0.25, 0.90],
                   nrows_ncols=(2, 1),
                   axes_pad=0.25,
                   cbar_location='right',
                   cbar_mode='each',
                   cbar_size='10%',
                   cbar_pad=0.04)

    s, im, _ = diffusion(1.0, 0.1, samples=200)
    # Return values are unused here. NOTE(review): presumably this is called
    # for its effect on `s` before the model image is compared below — confirm.
    runner.do_samples(s, 30, 0, quiet=True)
    # Middle slice along the third axis; `//` keeps it a valid integer index.
    nn = np.s_[:, :, im.shape[2] // 2]

    labels = ['Reference', 'Difference']
    diff = (im - s.get_model_image()[s.inner])[nn]
    diffm = 0.1  # fixed symmetric colour range for the residual
    im0 = gs[0].imshow(im[nn], vmin=0, vmax=1, cmap='bone_r')
    im1 = gs[1].imshow(diff, vmin=-diffm, vmax=diffm, cmap='RdBu')
    cb0 = pl.colorbar(im0, cax=gs[0].cax, ticks=[0, 1])
    cb1 = pl.colorbar(im1, cax=gs[1].cax, ticks=[-diffm, diffm])
    cb0.ax.set_yticklabels(['0', '1'])
    cb1.ax.set_yticklabels(['-%0.1f' % diffm, '%0.1f' % diffm])

    for i, label in enumerate(labels):
        gs[i].set_xticks([])
        gs[i].set_yticks([])
        gs[i].set_ylabel(label)

    # Scale factor applied to the stored exposure times (see comment below).
    aD = 1.0 / (25. / 0.15)

    symbols = ['o', '^', 'D', '>']
    times = None
    for i, snr in enumerate(snrs):
        c = common.COLORS[i]
        fn = prefix + '-snr-' + str(snr) + '.pkl'
        # Pickles must be read in binary mode; `with` closes the file.
        with open(fn, 'rb') as f:
            crb, val, _, pos, times = pickle.load(f)

        if i == 0:
            label0 = r"$\rm{SNR} = %i$ CRB" % snr
            label1 = r"$\rm{SNR} = %i$ Error" % snr
        else:
            label0 = r"$%i$, CRB" % snr
            label1 = r"$%i$, Error" % snr

        times *= aD  # a^2/D, where D=1, and a=5 (see first function)
        ax.plot(times, common.dist(crb), '-', c=c, lw=3, label=label0)
        ax.plot(times,
                common.errs(val, pos),
                symbols[i],
                ls='--',
                lw=2,
                c=c,
                label=label1,
                ms=12)

    # 80% glycerol value
    ax.vlines(0.100 * aD,
              1e-6,
              100,
              linestyle='-',
              lw=40,
              alpha=0.2,
              color='k')

    # 100% water value
    ax.vlines(0.100 * aD * 60,
              1e-6,
              100,
              linestyle='-',
              lw=40,
              alpha=0.2,
              color='b')

    ax.loglog()
    ax.set_ylim(5e-4, 2e0)
    # x-limits come from the times of the last SNR file loaded.
    ax.set_xlim(times[0], times[-1])
    ax.legend(loc='best', ncol=2, prop={'size': 18}, numpoints=1)
    ax.set_xlabel(r"$\tau_{\rm{exposure}} / (a^2/D)$")
    ax.set_ylabel(r"Position CRB, Error")
    ax.grid(False, which='both', axis='both')
    ax.set_title("Brownian motion")
Beispiel #8
0
def plot_tsne_selection_grid(z_pos,
                             x_pos,
                             z_neg,
                             vmin,
                             vmax,
                             fig_path,
                             labels=None,
                             fig_size=(9, 9),
                             g_j=7,
                             s=.5,
                             suffix='png'):
    """Save a grid of 2-D embedding scatter plots, one panel per column of `x_pos`.

    Each panel draws the `z_neg` points in light grey and overlays the
    `z_pos` points coloured by one column of `x_pos`, with its own colorbar
    on top.

    Parameters
    ----------
    z_pos : array, shape (n_pos, >=2)
        Coordinates of the coloured points (columns 0 and 1 are plotted).
    x_pos : array, shape (n_pos, ncol)
        Colour values; column k is shown in panel k.
    z_neg : array, shape (n_neg, >=2)
        Coordinates of the grey background points.
    vmin, vmax : sequence of float, length ncol
        Per-panel colour limits.
    fig_path : str
        Output path without extension.
    labels : sequence of str, optional
        Panel titles; defaults to '0', '1', ..., str(ncol - 1).
    fig_size : tuple of float
        Figure size in inches.
    g_j : int
        Number of panels per row.
    s : float
        Marker size.
    suffix : str
        File extension, also passed as the matplotlib output format.
    """
    ncol = x_pos.shape[1]
    n_rows = (ncol + g_j - 1) // g_j  # ceil(ncol / g_j)
    if labels is None:
        # `np.range` does not exist; the builtin gives the same labels.
        labels = [str(k) for k in range(ncol)]

    fig = plt.figure(figsize=fig_size)
    # `add_all` is not passed: it defaulted to True and was deprecated and
    # later removed from ImageGrid in newer Matplotlib releases.
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(n_rows, g_j),
        ngrids=ncol,
        aspect=True,
        direction="row",
        axes_pad=(0.15, 0.5),
        label_mode="1",
        share_all=True,
        cbar_location="top",
        cbar_mode="each",
        cbar_size="8%",
        cbar_pad="5%",
    )
    for k in range(ncol):
        ax = grid[k]
        ax.text(0,
                .92,
                labels[k],
                horizontalalignment='center',
                transform=ax.transAxes,
                size=20,
                weight='bold')
        ax.scatter(z_neg[:, 0],
                   z_neg[:, 1],
                   s=s,
                   marker='o',
                   c='lightgray',
                   alpha=0.5,
                   edgecolors='face')
        im = ax.scatter(z_pos[:, 0],
                        z_pos[:, 1],
                        s=s,
                        marker='o',
                        c=x_pos[:, k],
                        cmap=cm.jet,
                        edgecolors='face',
                        vmin=vmin[k],
                        vmax=vmax[k])
        ax.cax.colorbar(im)
        clean_axis(ax)
        ax.grid(False)
    fig.savefig('.'.join([fig_path, suffix]), format=suffix)
    plt.close(fig)
def event_display(config):
    """Write event-display PDFs for every request listed in a request file.

    ``config.input_file`` must be a sequence whose first element is the path
    of a text file with one request per line: ``<softmax> <root file> <entry>``.
    Blank lines are skipped. For each request the plots are written to
    ``<config.output_dir><get_class(root file)>_softmax<a>_<b>/`` where
    ``a.b`` is the softmax string (it must contain a '.').

    Side effects: replaces ``config.input_file`` by its first element, makes
    ``config.output_dir`` end with '/', creates the output directories, and
    calls ``sys.exit()`` if a ROOT file name contains none of the particle
    tags ``_gamma``, ``_e``, ``_mu``, ``_pi0``.
    """
    config.input_file = config.input_file[0]
    config.output_dir += ('' if config.output_dir.endswith('/') else '/')
    if not os.path.isdir(config.output_dir):
        os.mkdir(config.output_dir)
    print("Reading request from: " + str(config.input_file))
    print("output directory: " + str(config.output_dir))

    # Read all requests up front; `with` guarantees the file is closed.
    with open(config.input_file, 'r') as request_file:
        requests = request_file.readlines()

    for line in requests:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines
        softmax, input_file, ev = fields[0], fields[1], int(fields[2])

        print("now processing " + input_file + " at index " + str(ev))

        softmax_parts = softmax.split('.')
        write_dir = (config.output_dir + get_class(input_file) + "_softmax" +
                     softmax_parts[0] + '_' + softmax_parts[1] + "/")
        if not os.path.isdir(write_dir):
            os.mkdir(write_dir)

        if not any(tag in input_file
                   for tag in ("_gamma", "_e", "_mu", "_pi0")):
            print("Unknown input file particle type")
            sys.exit()

        _display_event(input_file, ev, write_dir)


def _display_event(input_file, ev, write_dir):
    """Open a WCSim ROOT file, save geometry and hit plots for entry `ev`, then close it."""
    cmap = matplotlib.cm.plasma
    root_file = ROOT.TFile(input_file, "read")
    try:
        tree = root_file.Get("wcsimT")
        print("number of entries in the tree: " + str(tree.GetEntries()))

        geotree = root_file.Get("wcsimGeoT")
        print("number of entries in the geometry tree: " +
              str(geotree.GetEntries()))
        geotree.GetEntry(0)
        geo = geotree.wcsimrootgeom

        r_max = _plot_geometry(geo, write_dir, cmap)
        _plot_event(tree, ev, geo, r_max, write_dir, cmap)
    finally:
        root_file.Close()


def _categorical_cmap(base_cmap, n_categories):
    """Return (colormap, BoundaryNorm) binning values into [k, k+1) for k = 0..n_categories-1."""
    colors = [base_cmap(i) for i in range(base_cmap.N)]
    cat_cmap = lsc.from_list('Custom cmap', colors, base_cmap.N)
    bounds = np.linspace(0, n_categories, n_categories + 1)
    return cat_cmap, matplotlib.colors.BoundaryNorm(bounds, cat_cmap.N)


def _save_scatter_figure(num, x, y, c, cmap, xlabel, ylabel, cbar_label,
                         path, norm=None, ticks=None, size=5):
    """Scatter `c`-coloured points at (x, y) on figure `num`, add a labelled colorbar and save to `path`."""
    fig = plt.figure(num=num, clear=True)
    fig.set_size_inches(10, 8)
    ax = fig.add_subplot(111)
    points = ax.scatter(x, y, c=c, s=size, cmap=cmap, norm=norm, marker='.')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    cbar = fig.colorbar(points, ticks=ticks, pad=0.1)
    cbar.set_label(cbar_label)
    fig.savefig(path)


def _save_module_map_figure(num, module_map, cmap, cbar_label, path):
    """Show a (row, column) wall map, flipped vertically, with a labelled colorbar and save it."""
    fig = plt.figure(num=num, clear=True)
    fig.set_size_inches(10, 4)
    ax = fig.add_subplot(111)
    image = ax.imshow(np.flip(module_map, axis=0), cmap=cmap)
    ax.set_xlabel('arc index')
    ax.set_ylabel('z index')
    cbar = fig.colorbar(image, pad=0.1)
    cbar.set_label(cbar_label)
    fig.savefig(path)


def _save_histogram_figure(num, values, xlabel, path):
    """Save a 50-bin density histogram of `values` drawn on figure `num`."""
    fig = plt.figure(num=num, clear=True)
    fig.set_size_inches(10, 8)
    ax = fig.add_subplot(111)
    ax.hist(values, 50, density=True, facecolor='blue', alpha=0.75)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("PMT's above threshold")
    fig.savefig(path)


def _save_channel_grid_figure(num, images, cmap, path):
    """Draw up to 20 wall maps on a 4x5 ImageGrid sharing one colour scale and one colorbar.

    All panels use the same vmin/vmax so that the single colorbar is valid
    for every panel.
    """
    fig = plt.figure(num=num, clear=True)
    fig.set_size_inches(15, 5)
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(4, 5),
        axes_pad=0.0,
        share_all=True,
        label_mode="L",
        cbar_location="top",
        cbar_mode="single",
    )
    vmin = min(np.amin(img) for img in images)
    vmax = max(np.amax(img) for img in images)
    shown = None
    for i, img in enumerate(images):
        shown = grid[i].imshow(np.flip(img, axis=0),
                               cmap=cmap,
                               vmin=vmin,
                               vmax=vmax)
    grid.cbar_axes[0].colorbar(shown)
    fig.savefig(path)


def _biggest_trigger(superevent):
    """Return the index of the trigger with the most Cherenkov digihits (first one on ties, 0 if none)."""
    best_index, best_hits = 0, 0
    for idx in range(superevent.GetNumberOfEvents()):
        hits = superevent.GetTrigger(idx).GetNcherenkovdigihits()
        if hits > best_hits:
            best_index, best_hits = idx, hits
    return best_index


def _plot_geometry(geo, write_dir, base_cmap):
    """Save diagnostic scatter plots of the PMT layout (figures 101-107) and return the maximum PMT radius.

    Positions use the axis permutation x = GetPosition(2),
    y = GetPosition(0), z = GetPosition(1).
    """
    cmap_pmt, norm_pmt = _categorical_cmap(base_cmap, 19)
    cmap_row, norm_row = _categorical_cmap(base_cmap, 16)
    cmap_col, norm_col = _categorical_cmap(base_cmap, 40)

    num_pmts = geo.GetWCNumPMT()
    # PMTs are visited in random order. NOTE(review): presumably so that
    # overlapping markers do not systematically hide one category — confirm.
    tube_ids = np.arange(num_pmts)
    np.random.shuffle(tube_ids)
    modules = module_index(tube_ids)
    in_module = tube_ids % 19

    x = np.zeros(num_pmts)
    y = np.zeros(num_pmts)
    z = np.zeros(num_pmts)
    for i, tube in enumerate(tube_ids):
        pmt = geo.GetPMT(tube)
        x[i] = pmt.GetPosition(2)
        y[i] = pmt.GetPosition(0)
        z[i] = pmt.GetPosition(1)

    r_max = np.amax(np.hypot(x, y))
    arc = r_max * np.arctan2(y, x)  # arc length along the wall at radius r_max

    wall = np.where(is_barrel(modules))
    top = np.where(is_top(modules))
    bottom = np.where(is_bottom(modules))
    wall_row, wall_col = row_col(modules[wall])

    arc_label = 'arc along the wall'
    _save_scatter_figure(101, arc, z, in_module, cmap_pmt, arc_label, 'z',
                         "pmt in module",
                         write_dir + "pos_arc_z_disp_all_tubes.pdf",
                         norm=norm_pmt, ticks=range(20))
    _save_scatter_figure(102, x, y, in_module, cmap_pmt, 'x', 'y',
                         "pmt in module",
                         write_dir + "pos_x_y_disp_all_tubes.pdf",
                         norm=norm_pmt, ticks=range(20))
    _save_scatter_figure(103, arc[wall], z[wall], in_module[wall], cmap_pmt,
                         arc_label, 'z', "pmt in module",
                         write_dir + "pos_arc_z_disp_wall_tubes.pdf",
                         norm=norm_pmt, ticks=range(20))
    _save_scatter_figure(104, arc[wall], z[wall], wall_row, cmap_row,
                         arc_label, 'z', "wall module row",
                         write_dir + "pos_arc_z_disp_wall_tubes_color_row.pdf",
                         norm=norm_row, ticks=range(16))
    _save_scatter_figure(105, arc[wall], z[wall], wall_col, cmap_col,
                         arc_label, 'z', "wall module column",
                         write_dir + "pos_arc_z_disp_wall_tubes_color_col.pdf",
                         norm=norm_col, ticks=range(40))
    _save_scatter_figure(106, x[top], y[top], in_module[top], cmap_pmt,
                         'x', 'y', "pmt in module",
                         write_dir + "pos_x_y_disp_top_tubes.pdf",
                         norm=norm_pmt, ticks=range(20))
    _save_scatter_figure(107, x[bottom], y[bottom], in_module[bottom],
                         cmap_pmt, 'x', 'y', "pmt in module",
                         write_dir + "pos_x_y_disp_bottom_tubes.pdf",
                         norm=norm_pmt, ticks=range(20))
    return r_max


def _plot_event(tree, ev, geo, r_max, write_dir, cmap):
    """Save hit plots (figures 1-11) for entry `ev`, using its trigger with the most digitized hits.

    Prints a message and returns without plotting if that trigger has no
    digitized hits. `r_max` converts hit azimuth into arc length along the
    wall.
    """
    tree.GetEvent(ev)
    superevent = tree.wcsimrootevent
    print("number of sub events: " + str(superevent.GetNumberOfEvents()))

    index = _biggest_trigger(superevent)
    trigger = superevent.GetTrigger(index)
    header = trigger.GetHeader()
    print("event date and number: " + str(header.GetDate()) + " " +
          str(header.GetEvtNum()))

    nhits = trigger.GetNcherenkovdigihits()
    print("Ncherenkovdigihits " + str(nhits))
    if nhits == 0:
        print("event, trigger has no hits " + str(ev) + " " + str(index))
        return

    # PMT numbering, per the original author's notes:
    # - Tube ids start at 1 and are contiguous. They are made 0-based below.
    # - Each 19 consecutive PMTs form one mPMT module, so (id - 1) // 19 is
    #   the module number.
    # - id % 19 is the position in the module: 1-12 is the outer ring,
    #   13-18 is the inner ring, and 0 is the centre PMT.
    # - Modules are ordered as follows. First come the barrel rings from the
    #   second-highest down to the lowest (the highest ring is skipped).
    #   Then the bottom end-cap, row by row. Then the skipped highest barrel
    #   ring. Then the top end-cap, row by row.
    # - The end-cap rows were first noted as 6, 8, 10, 10, 10, 10, 10, 10,
    #   8, 6 modules. They were later observed as 2, 6, 8, 10, 10, 12 and
    #   then back down again.
    pmt_index = np.zeros(nhits, dtype=np.int32)
    x = np.zeros(nhits)
    y = np.zeros(nhits)
    z = np.zeros(nhits)
    u = np.zeros(nhits)
    v = np.zeros(nhits)
    w = np.zeros(nhits)
    q = np.zeros(nhits)
    t = np.zeros(nhits)

    digihits = trigger.GetCherenkovDigiHits()
    for i in range(nhits):
        hit = digihits.At(i)
        tube = hit.GetTubeId() - 1
        pmt_index[i] = tube
        pmt = geo.GetPMT(tube)
        # Same axis permutation as the geometry plots.
        x[i] = pmt.GetPosition(2)
        y[i] = pmt.GetPosition(0)
        z[i] = pmt.GetPosition(1)
        u[i] = pmt.GetOrientation(2)
        v[i] = pmt.GetOrientation(0)
        w[i] = pmt.GetOrientation(1)
        q[i] = hit.GetQ()
        t[i] = hit.GetT()

    modules = module_index(pmt_index)
    in_module = pmt_in_module_id(pmt_index)
    wall = np.where(is_barrel(modules))
    top = np.where(is_top(modules))
    bottom = np.where(is_bottom(modules))

    arc_wall = (r_max * np.arctan2(y, x))[wall]
    z_wall = z[wall]
    q_wall = q[wall]

    # Dense wall tensor indexed (module row, module column, channel).
    # Channels 0-18 hold charge and 19-37 hold time for the PMTs of a module.
    wall_row, wall_col = row_col(modules[wall])
    in_module_wall = in_module[wall]
    wall_rect = np.zeros((16, 40, 38))
    wall_rect[wall_row, wall_col, in_module_wall] = q_wall
    wall_rect[wall_row, wall_col, in_module_wall + 19] = t[wall]
    wall_q_max = np.amax(wall_rect[:, :, 0:19], axis=-1)
    wall_q_sum = np.sum(wall_rect[:, :, 0:19], axis=-1)

    # Arrow lengths for the quiver plot: charge scaled so the largest is 500.
    scaled_q = 500 * q / np.amax(q)
    suffix = "_ev_{}_trig_{}.pdf".format(ev, index)

    fig1 = plt.figure(num=1, clear=True)
    fig1.set_size_inches(10, 8)
    ax1 = fig1.add_subplot(111, projection='3d', azim=35, elev=20)
    hits_3d = ax1.scatter(x, y, z, c=q, s=2, alpha=0.4, cmap=cmap, marker='.')
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    ax1.set_zlabel('z')
    cbar1 = fig1.colorbar(hits_3d, pad=0.03)
    cbar1.set_label("charge")
    fig1.savefig(write_dir + "ev_disp" + suffix)

    # Arrows are coloured by hit time. The colorbar uses the same colormap
    # and norm, so it matches the arrows. The first norm call autoscales the
    # norm to `t`.
    time_cmap = plt.cm.spring
    time_norm = plt.Normalize()
    arrow_colors = time_cmap(time_norm(t))
    fig2 = plt.figure(num=2, clear=True)
    fig2.set_size_inches(10, 8)
    ax2 = fig2.add_subplot(111, projection='3d', azim=35, elev=20)
    ax2.quiver(x, y, z,
               u * scaled_q, v * scaled_q, w * scaled_q,
               colors=arrow_colors,
               alpha=0.4,
               cmap=time_cmap)
    ax2.set_xlabel('x')
    ax2.set_ylabel('y')
    ax2.set_zlabel('z')
    time_mappable = matplotlib.cm.ScalarMappable(cmap=time_cmap, norm=time_norm)
    time_mappable.set_array([])
    cbar2 = fig2.colorbar(time_mappable, pad=0.03)
    cbar2.set_label("time")
    fig2.savefig(write_dir + "ev_disp_quiver" + suffix)

    _save_scatter_figure(3, arc_wall, z_wall, q_wall, cmap,
                         'arc along the wall', 'z', "charge",
                         write_dir + "ev_disp_wall" + suffix, size=2)
    _save_scatter_figure(4, x[top], y[top], q[top], cmap, 'x', 'y', "charge",
                         write_dir + "ev_disp_top" + suffix, size=2)
    _save_scatter_figure(5, x[bottom], y[bottom], q[bottom], cmap, 'x', 'y',
                         "charge", write_dir + "ev_disp_bottom" + suffix,
                         size=2)

    _save_module_map_figure(6, wall_q_sum, cmap, "total charge in module",
                            write_dir + "q_sum_disp" + suffix)
    _save_module_map_figure(7, wall_q_max, cmap, "maximum charge in module",
                            write_dir + "q_max_disp" + suffix)

    _save_histogram_figure(8, q, 'charge', write_dir + "q_pmt_disp" + suffix)
    _save_histogram_figure(9, t, 'time', write_dir + "t_pmt_disp" + suffix)

    # The charge grid has the 19 per-channel maps plus the per-module maximum
    # in the 20th cell. The time grid has the 19 per-channel time maps.
    charge_maps = [wall_rect[:, :, i] for i in range(19)] + [wall_q_max]
    _save_channel_grid_figure(10, charge_maps, cmap,
                              write_dir + "q_disp_grid" + suffix)
    time_maps = [wall_rect[:, :, i + 19] for i in range(19)]
    _save_channel_grid_figure(11, time_maps, cmap,
                              write_dir + "t_disp_grid" + suffix)
Beispiel #10
0
def ransac_im_fit(im,
                  mode=1,
                  residual_threshold=0.1,
                  min_samples=10,
                  max_trials=1000,
                  model_f=None,
                  p0=None,
                  mask=None,
                  scale=False,
                  fract=1,
                  param_dict=None,
                  plot=False,
                  axes=(-2, -1),
                  widget=None):
    '''
    Fits a plane, polynomial, convex paraboloid, arbitrary function, or
    smoothing spline to an image using the RANSAC algorithm.

    Parameters
    ----------
    im : ndarray
        ndarray with images to fit to.
    mode : integer [0-4]
        Specifies model used for fit.
        0 is function defined by `model_f`.
        1 is plane.
        2 is quadratic.
        3 is concave paraboloid with offset.
        4 is smoothing spline.
        Any other value raises ValueError.
    model_f : callable or None
        Function to be fitted.
        Definition is model_f(p, *args), where p is 1-D iterable of params
        and args is iterable of (x, y, z) arrays of point cloud coordinates.
        See examples.
    p0 : tuple
        Initial guess of fit params for `model_f`. Required for mode 0.
    mask : 2-D boolean array
        Array with which to mask data. True values are ignored.
    scale : bool
        If True, `residual_threshold` is scaled, separately for each image,
        by the stdev of that image's selected (unmasked, sampled) data.
    fract : scalar (0, 1]
        Fraction of data used for fitting, chosen randomly. Non-used data
        locations are set as nans in `inliers`.
    residual_threshold : float
        Maximum distance for a data point to be classified as an inlier.
    min_samples : int or float
        The minimum number of data points to fit a model to.
        If an int (Python or numpy integer), the value is the number of pixels.
        If a float, the value is a fraction (0.0, 1.0] of the total number of
        pixels (of `mask` if given, otherwise of one image).
        If the resulting count is at least the number of selected data
        points, all of them are fitted directly, without RANSAC.
    max_trials : int, optional
        Maximum number of iterations for random sample selection.
    param_dict : None or dictionary.
        If not None, the dictionary is passed to the model estimator.
        For arbitrary functions, this is passed to scipy.optimize.leastsq.
        For spline fitting, this is passed to scipy.interpolate.SmoothBivariateSpline.
        All other models take no parameters.
    plot : bool
        If True, the data, including inliers, model, etc are plotted.
    axes : length 2 iterable
        Indices of the input array with images.
    widget : QtWidget
        A qt widget that must contain as many widgets or dock widget as there is popup window

    Returns
    -------
    Tuple of fit, inliers, n, where:
    fit : 2-D array
        Image of fitted model.
    inliers : 2-D array
        Boolean array describing inliers. When `mask` is given or `fract`
        is not 1, this is a float array with unused locations set to nan.
    n : array or None
        Normal of plane fit. `None` for other models.

    Raises
    ------
    NotImplementedError
        If `mode` is 0 and `p0` is None.
    ValueError
        If `mode` is not one of 0-4.

    Notes
    -----
    See skimage.measure.ransac for details of RANSAC algorithm.

    `min_samples` should be chosen appropriate to the size of the image
    and to the variation in the image.

    Increasing `residual_threshold` increases the fraction of the image
    fitted to.

    The entire image can be fitted to without RANSAC by setting:
    max_trials=1, min_samples=1.0, residual_threshold=`x`, where `x` is a
    suitably large value. This also works when `mask` or `fract` reduce the
    number of fitted points.

    Examples
    --------
    `model_f` for paraboloid with offset:

    >>> def model_f(p, *args):
    ...     (x, y, z) = args
    ...     m = np.abs(p[0])*((x-p[1])**2 + (y-p[2])**2) + p[3]
    ...     return m
    >>> p0 = (0.1, 10, 20, 0)


    To plot fit, inliers etc:

    >>> from fpd.ransac_tools import ransac_im_fit
    >>> import matplotlib as mpl
    >>> from numpy.ma import masked_where
    >>> import numpy as np
    >>> import matplotlib.pylab as plt
    >>> plt.ion()


    >>> cmap = mpl.cm.gray
    >>> cmap.set_bad('r')

    >>> image = np.random.rand(*(64,)*2)
    >>> fit, inliers, n = ransac_im_fit(image, mode=1)
    >>> cor_im = image-fit

    >>> pct = 0.5
    >>> vmin, vmax = np.percentile(cor_im, [pct, 100-pct])
    >>>
    >>> f, axs = plt.subplots(1, 4, sharex=True, sharey=True)
    >>> _ = axs[0].matshow(image, cmap=cmap)
    >>> _ = axs[1].matshow(masked_where(inliers==False, image), cmap=cmap)
    >>> _ = axs[2].matshow(fit, cmap=cmap)
    >>> _ = axs[3].matshow(cor_im, vmin=vmin, vmax=vmax)



    To plot plane normal vs threshold:

    >>> from fpd.ransac_tools import ransac_im_fit
    >>> from numpy.ma import masked_where
    >>> import numpy as np
    >>> from tqdm import tqdm
    >>> import matplotlib.pylab as plt
    >>> plt.ion()

    >>> image = np.random.rand(*(64,)*2)
    >>> ns = []
    >>> rts = np.logspace(0, 1.5, 5)
    >>> for rt in tqdm(rts):
    ...     nis = []
    ...     for i in range(64):
    ...         fit, inliers, n = ransac_im_fit(image, residual_threshold=rt, max_trials=10)
    ...         nis.append(n)
    ...     ns.append(np.array(nis).mean(0))
    >>> ns = np.array(ns)

    >>> thx = np.arctan2(ns[:,1], ns[:,2])
    >>> thy = np.arctan2(ns[:,0], ns[:,2])
    >>> thx = np.rad2deg(thx)
    >>> thy = np.rad2deg(thy)

    >>> _ = plt.figure()
    >>> _ = plt.semilogx(rts, thx)
    >>> _ = plt.semilogx(rts, thy)

    '''

    # Set model
    # Functions defining classes are needed to pass parameters since class must
    # not be instantiated or are monkey patched (only in spline implementation)
    if mode == 0:
        # generate model_class with passed function
        if p0 is None:
            raise NotImplementedError('p0 must be specified.')
        model_class = rt._model_class_gen(model_f, p0, param_dict)
    elif mode == 1:
        # linear
        model_class = rt._Plane3dModel
    elif mode == 2:
        # quadratic
        model_class = rt._Poly3dModel
    elif mode == 3:
        # concave paraboloid
        model_class = rt._Poly3dParaboloidModel
    elif mode == 4:
        # spline
        class _Spline3dModel_monkeypatched(_Spline3dModel):
            pass

        model_class = _Spline3dModel_monkeypatched
        model_class.param_dict = param_dict
    else:
        # Fail early instead of hitting an unbound `model_class` later.
        raise ValueError('mode must be one of 0, 1, 2, 3, 4; got %r.' % (mode,))

    multiim = False
    if im.ndim > 2:
        multiim = True
        axes = [int(el) for el in axes]
        ims, unflat_shape = seq_image_array(im, axes)
        pbar = tqdm(total=ims.shape[0])
    else:
        ims = im[None]

    # Resolve `min_samples` to a pixel count once, without rebinding the
    # argument. All images in `ims` share one shape (and `mask`), so the count
    # is the same for each. Numpy integers are counts, not fractions.
    if isinstance(min_samples, (int, np.integer)):
        n_min_samples = int(min_samples)
    else:
        if mask is None:
            n_pix = int(np.prod(ims.shape[1:]))
        else:
            n_pix = np.size(mask)
        n_min_samples = int(n_pix * min_samples)
        print("min_samples is set to: %d" % (n_min_samples))

    fits = []
    inlierss = []
    ns = []

    for imi in ims:
        # set data: (x, y, z) point cloud of the unmasked pixels
        yy, xx = np.indices(imi.shape)
        zz = imi
        if mask is None:
            keep = np.ones(imi.size, dtype=bool)
        else:
            # True in `mask` means "ignore this pixel".
            keep = np.logical_not(mask).ravel()
        data = np.column_stack([xx.flat[keep], yy.flat[keep], zz.flat[keep]])

        # randomly select data
        sel = np.random.rand(data.shape[0]) <= fract
        data = data[sel]

        # Per-image threshold: scaling must not accumulate across images,
        # so the caller's `residual_threshold` is never reassigned.
        if scale:
            threshold = residual_threshold * np.std(data[:, 2])
        else:
            threshold = residual_threshold

        # Fit all points directly when RANSAC would need at least as many
        # samples as are available.
        full_fit = n_min_samples >= data.shape[0]

        if not full_fit:
            # do ransac fit
            model, inliers = ransac(data=data,
                                    model_class=model_class,
                                    min_samples=n_min_samples,
                                    residual_threshold=threshold,
                                    max_trials=max_trials)
        else:
            model = model_class()
            inliers = np.ones(data.shape[0], dtype=bool)

        # get params from fit with all inliers
        model.estimate(data[inliers])
        # get model over all x, y
        args = (xx.flatten(), yy.flatten(), zz.flatten())
        fit = model.my_model(model.params, *args).reshape(imi.shape)

        if mask is None and fract == 1:
            inliers = inliers.reshape(imi.shape)
        else:
            # Scatter inlier flags back to pixel positions; pixels that were
            # masked out or not sampled stay nan.
            inliers_nans = np.full(imi.size, np.nan)
            sel_fit = np.arange(imi.size)[keep][sel]
            inliers_nans[sel_fit] = inliers
            inliers = inliers_nans.reshape(imi.shape)

        # calculate normal for plane
        if mode == 1:
            # linear
            C = model.params
            n = np.array([-C[0], -C[1], 1])
            n_mag = np.linalg.norm(n, ord=None, axis=0)
            n = n / n_mag
        else:
            # non-linear
            n = None

        if plot:
            import matplotlib.pylab as plt
            import matplotlib as mpl
            from copy import copy
            from numpy.ma import masked_where
            from mpl_toolkits.axes_grid1 import ImageGrid

            plt.ion()
            # Copy so that marking bad pixels red does not alter matplotlib's
            # shared, globally registered gray colormap.
            cmap = copy(mpl.cm.gray)
            cmap.set_bad('r')

            cor_im = imi - fit
            pct = 0.1
            vmin, vmax = np.percentile(cor_im, [pct, 100 - pct])
            if widget is None:
                fig = plt.figure()
            else:
                docked = widget.setup_docking("Circular Center", "Bottom")
                fig = docked.get_fig()
                fig.clf()

            grid = ImageGrid(fig,
                             111,
                             nrows_ncols=(1, 4),
                             axes_pad=0.1,
                             share_all=True,
                             label_mode="L",
                             cbar_location="right",
                             cbar_mode="single")

            images = [imi, masked_where(inliers == False, imi), fit, cor_im]
            titles = ['Image', 'Inliers', 'Fit', 'Corrected']
            for i, image in enumerate(images):
                img = grid[i].imshow(image, cmap=cmap, interpolation='nearest')
                grid[i].set_title(titles[i])
            # Colour limits and colorbar apply to the last (corrected) image.
            img.set_clim(vmin, vmax)
            grid.cbar_axes[0].colorbar(img)

        fits.append(fit)
        inlierss.append(inliers)
        ns.append(n)
        if multiim:
            pbar.update(1)
    fit = np.array(fits)
    inliers = np.array(inlierss)
    n = np.array(ns)

    if multiim:
        pbar.close()

        # reshape
        fit = unseq_image_array(fit, axes, unflat_shape)
        inliers = unseq_image_array(inliers, axes, unflat_shape)
        n = unseq_image_array(n, axes, unflat_shape)
    else:
        fit = fit[0]
        inliers = inliers[0]
        n = n[0]

    return (fit, inliers, n)
Beispiel #11
0
    def test_calc_likelihoods_2():
        """Check that ``cloudpy._calc_likelihoods`` recovers a known toy model.

        A 5x5 Av image and a 5-channel HI cube are built by hand. The
        expected best fit is intercept 0.9, DGR 0.5 and velocity width 2.
        When ``save_plots`` is true, diagnostic PNGs are written to
        ``plot_dir``: the HI cube, the Av image, the derived N(HI) image and
        the Av model.
        """
        import os

        from numpy.testing import assert_almost_equal
        from myimage_analysis import calculate_nhi
        import matplotlib.pyplot as plt
        from mpl_toolkits.axes_grid1 import ImageGrid

        # Toggle for the diagnostic figures written below.
        save_plots = True
        # NOTE(review): machine-specific output directory; savefig fails when
        # it does not exist.
        plot_dir = '/usr/users/ezbc/Desktop'
        # Factor used to undo calculate_nhi's scaling. Presumably the HI
        # conversion 1.823e18 cm^-2 / (K km/s) in units of 1e20 cm^-2 -- TODO
        # confirm. Both call sites below previously used different values
        # (1.832e-2 vs 1.823e-2), which looked like a digit transposition.
        nhi_factor = 1.823e-2

        def _save_image_grid(images, filename):
            """Draw ``images`` in one row on a shared 0-4 grey scale and save.

            The figure is closed after saving so repeated calls do not
            accumulate open figures.
            """
            fig = plt.figure(figsize=(4, 4))
            imagegrid = ImageGrid(fig, (1, 1, 1),
                                  nrows_ncols=(1, len(images)),
                                  ngrids=len(images),
                                  cbar_mode="single",
                                  cbar_location='top',
                                  cbar_pad="2%",
                                  cbar_size='3%',
                                  axes_pad=0.1,
                                  aspect=True,
                                  label_mode='L',
                                  share_all=True)
            cmap = plt.get_cmap('Greys', 5)
            for i, image in enumerate(images):
                im = imagegrid[i].imshow(image,
                                         origin='lower',
                                         cmap=cmap,
                                         interpolation='none',
                                         vmin=0,
                                         vmax=4)
            # All panels share vmin/vmax, so one colorbar describes them all.
            imagegrid.cbar_axes[0].colorbar(im)
            fig.savefig(os.path.join(plot_dir, filename))
            plt.close(fig)

        av_image = np.array([[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 2, 1, 0],
                             [np.nan, 1, 1, 1, 0], [0, 0, 0, 0, 0]])

        av_image_error = 0.1 * np.ones(av_image.shape)

        hi_cube = np.zeros((5, 5, 5))

        # Inner channels (1-3) are correlated with av; the edge channels
        # (0 and 4) hold a few unrelated values and nans.
        hi_cube[:, :, :] = np.array([
            [
                [1., 0., 0., 0., 0.],
                [np.nan, 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [1., 0., 0., 0., 10.],
            ],
            [
                [0., 0., 0., 0., 0.],
                [0., 0., 2., 0., 0.],
                [0., 0., 4., 0., 0.],
                [0., 0., 2., 0., 0.],
                [0., 0., 0., 0., 0.],
            ],
            [
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 2., 0.],
                [0., 0., 0., 2., 0.],
                [0., 0., 0., 2., np.nan],
                [0., 0., 0., 0., 0.],
            ],
            [
                [0., 0., 0., 0., 0.],
                [0., 2., 0., 0., 0.],
                [0., 2., 0., 0., 0.],
                [0., 2., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
            ],
            [
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., np.nan],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [1., 0., 0., 0., 0.2],
            ],
        ])

        if save_plots:
            _save_image_grid([hi_cube[i, :, :] for i in range(5)],
                             'hi_cube.png')

        hi_vel_axis = np.arange(0, 5, 1)

        # add intercept
        intercept_answer = 0.9
        av_image = av_image + intercept_answer

        if save_plots:
            _save_image_grid([av_image], 'av.png')

        width_grid = np.arange(0, 5, 1)
        dgr_grid = np.arange(0, 1, 0.1)
        intercept_grid = np.arange(-1, 1, 0.1)
        vel_center = 2

        results = cloudpy._calc_likelihoods(
            hi_cube=hi_cube / nhi_factor,
            hi_vel_axis=hi_vel_axis,
            vel_center=vel_center,
            av_image=av_image,
            av_image_error=av_image_error,
            width_grid=width_grid,
            dgr_grid=dgr_grid,
            intercept_grid=intercept_grid,
        )

        dgr_answer = 1 / 2.0
        width_answer = 2
        width = results['width_max']
        dgr = results['dgr_max']
        intercept = results['intercept_max']
        print(width)

        vel_range = (vel_center - width / 2.0, vel_center + width / 2.0)

        nhi_image = calculate_nhi(cube=hi_cube,
                                  velocity_axis=hi_vel_axis,
                                  velocity_range=vel_range) / nhi_factor
        if save_plots:
            _save_image_grid([nhi_image], 'nhi.png')
            _save_image_grid([nhi_image * dgr + intercept], 'av_model.png')

        print('residuals = ')
        print(av_image - (nhi_image * dgr + intercept))
        print('dgr', dgr)
        print('intercept', intercept)
        print('width', width)

        assert_almost_equal(results['intercept_max'], intercept_answer)
        assert_almost_equal(results['dgr_max'], dgr_answer)
        assert_almost_equal(results['width_max'], width_answer)
Beispiel #12
0
fig = plt.figure(figsize=(6, 6))

# --- Input data -------------------------------------------------------------
# Three interleaved row subsets of the sample array: rows 0, 3, 6, ...;
# rows 1, 4, 7, ...; rows 2, 5, 8, ...
Z = cbook.get_sample_data("axes_grid/bivariate_normal.npy", np_load=True)
ZS = [Z[start::3, :] for start in range(3)]

# Image extent as (left, right, bottom, top); the right edge is divided by 3.
x_lo, x_hi, y_lo, y_hi = -3, 4, -4, 3
extent = (x_lo, x_hi / 3., y_lo, y_hi)

# --- Demo 1: a separate colorbar above each axes ----------------------------
grid_options = dict(
    nrows_ncols=(1, 3),
    axes_pad=0.05,
    label_mode="1",
    share_all=True,
    cbar_location="top",
    cbar_mode="each",
    cbar_size="7%",
    cbar_pad="1%",
)
grid = ImageGrid(fig, 211, **grid_options)  # 211: same slot as subplot(211)

for panel, (ax, z) in enumerate(zip(grid, ZS)):
    im = ax.imshow(z, origin="lower", extent=extent)
    cb = ax.cax.colorbar(im)
    # Every panel except the first gets explicit colorbar ticks.
    if panel > 0:
        cb.set_ticks([-1, 0, 1])

panel_titles = ("Image 1", "Image 2", "Image 3")
for ax, im_title in zip(grid, panel_titles):
    t = add_inner_title(ax, im_title, loc='lower left')
Beispiel #13
0
y_test = np_utils.to_categorical(y_test, num_classes)

# Class distribution of the training labels. Labels are one-hot, so argmax
# recovers the integer class id; one bin per class, centred on its id.
cat_y = np.argmax(y_train, 1)
f, ax = plt.subplots()
n, bins, patches = ax.hist(cat_y,
                           bins=range(num_classes + 1),
                           align='left',
                           rwidth=0.5)
ax.set_xticks(bins[:-1])
ax.set_xlabel('Class Id')
ax.set_ylabel('# Samples')
ax.set_title('CIFAR-10 Class Distribution')

# Sample gallery: one row per class, showing the first `samples_per_class`
# training images of that class.
samples_per_class = 10
fig = plt.figure(figsize=(10., 10.))
grid = ImageGrid(
    fig,
    111,  # similar to subplot(111)
    nrows_ncols=(num_classes, samples_per_class),
    axes_pad=(0.03, 0.1),  # pad between axes in inch
)

for cls_id in range(num_classes):
    cat_im = X_train[cat_y == cls_id]

    for j in range(samples_per_class):
        # Separate name so the histogram's `ax` is not shadowed.
        cell = grid[samples_per_class * cls_id + j]
        cell.imshow(cat_im[j].squeeze(), cmap='gray')
        cell.axis('off')
        cell.grid(False)
Beispiel #14
0
def tiled_axis(ts, filename=None):
    """Lay out one day per grid row and write a PNG of the figure per day.

    The figure holds a 3-row ImageGrid. For each day returned by
    ``glucose.get_days(ts.time)``, the matching row is configured:
    x-limits ``[day, day + delta]``, a grid, 6-hourly minor date ticks and
    rotated tick labels. The whole figure is then written to
    ``'<day.isoformat()>-<index>.png'``.

    At most 3 days can be drawn. Additional days are skipped with a
    warning rather than raising IndexError partway through.

    Parameters
    ----------
    ts : object
        Must expose ``ts.time``. Presumably a glucose time series with
        datetime samples -- TODO confirm against callers.
    filename : any, optional
        Unused; kept for backward compatibility.

    Returns
    -------
    FigureCanvas
        The canvas the figure is attached to.
    """
    n_rows = 3
    fig = Figure((2.56 * 4, 2.56 * 4), 300)
    canvas = FigureCanvas(fig)

    grid = ImageGrid(
        fig,
        111,  # similar to subplot(111)
        nrows_ncols=(n_rows, 1),
        axes_pad=0.5,
        add_all=True,
        label_mode="L",
    )
    # Width of each day's x window: 2.5 days past the day's start.
    delta = dates.relativedelta(days=2, hours=12)
    # XXX: gets a list of days.
    timestamps = glucose.get_days(ts.time)

    fig.autofmt_xdate()

    def get_axis(ax, limit):
        """Apply x-limits, grid and minor date ticks to ``ax``; log the result."""
        xmin, xmax = limit
        ax.set_xlim([xmin, xmax])

        ax.grid(True)

        minorLocator = dates.HourLocator(interval=6)
        minorFormatter = dates.AutoDateFormatter(minorLocator)

        ax.xaxis.set_minor_locator(minorLocator)
        ax.xaxis.set_minor_formatter(minorFormatter)

        plt.setp(ax.get_xminorticklabels(), rotation=30, fontsize='small')
        plt.setp(ax.get_xmajorticklabels(), rotation=30, fontsize='medium')

        # Log the limits as actually applied by matplotlib.
        xmin, xmax = ax.get_xlim()

        log.info(
            pformat({
                'xlim': [dates.num2date(xmin),
                         dates.num2date(xmax)],
                'xticks': dates.num2date(ax.get_xticks()),
            }))

    if len(timestamps) > n_rows:
        log.warning('tiled_axis: %d days given but only %d grid rows; '
                    'extra days are skipped', len(timestamps), n_rows)

    for i, day in enumerate(timestamps[:n_rows]):
        ax = grid[i]
        get_axis(ax, [day, day + delta])
        name = '%s-%d.png' % (day.isoformat(), i)
        canvas.print_figure(name)

    return canvas
Beispiel #15
0
def plot_segmask_3cl_input(y_pred,
                           x_in,
                           y_true=None,
                           class_to_plot=(2, 3),
                           input_to_plot=1,
                           input_name='channel 1',
                           xtick_int=50,
                           ytick_int=50,
                           show_plt=True,
                           save_imag=True,
                           imag_name='pred_mask_input',
                           save_as='pdf'):
    '''Function to plot segmentation mask (3 classes)
       including 1 input channel,
       this is a prototype version ==> can be combined with plot_segmask_input

    Parameters
    ----------
    y_pred : ndarray
        Per-pixel class scores; the predicted class is the argmax over the
        last axis, numbered from 1.
    x_in : ndarray
        Input image with channels on the last axis.
    y_true : ndarray or None
        Ground truth in the same layout as `y_pred`. If given, a ground-truth
        panel is drawn between the prediction and the input.
    class_to_plot : pair of int
        1-based class ids drawn at 0.5 and 1.0 respectively; every other
        pixel is 0.
    input_to_plot : int
        1-based index of the input channel to show.
    input_name : str
        Title of the input panel.
    xtick_int, ytick_int : int
        Tick spacing in pixels.
    show_plt : bool
        Call ``plt.show()`` (after saving, if saving).
    save_imag : bool
        Save the figure to ``imag_name + '.' + save_as``.
    imag_name, save_as : str
        Output file stem and extension.
    '''

    def _two_class_mask(scores):
        """Map argmax class to 0.5 (first class), 1.0 (second class), else 0."""
        labels = np.argmax(scores, axis=-1) + 1  # 1-based class ids
        # Boolean weights give exact 0.5 / 1.0 values, with no division by
        # the class id.
        return (0.5 * (labels == class_to_plot[0]) +
                1.0 * (labels == class_to_plot[1]))

    pred_mask = _two_class_mask(y_pred)
    if y_true is not None:
        # Plot prediction mask, ground truth and input image (1 channel)
        gr_truth = _two_class_mask(y_true)
        grid_cmap = ['jet', 'gray', 'gray']
        grid_imag = [pred_mask, gr_truth, x_in[..., input_to_plot - 1]]
        grid_title = [r'Segmentation mask', r'Ground truth', input_name]
        fig_width = 10.5
        n_cols = 3
    else:
        # Plot prediction mask and input image (1 channel)
        grid_cmap = ['jet', 'gray']
        grid_imag = [pred_mask, x_in[..., input_to_plot - 1]]
        grid_title = [r'Segmentation mask', input_name]
        fig_width = 6.8
        n_cols = 2

    fig = plt.figure(figsize=(fig_width, 4))
    grid = ImageGrid(fig,
                     rect=[0.1, 0.07, 0.85, 0.9],
                     nrows_ncols=(1, n_cols),
                     axes_pad=0.25,
                     share_all=True)
    for i in range(n_cols):
        ax = grid[i]
        ax.imshow(grid_imag[i], vmin=0, vmax=1, cmap=grid_cmap[i])
        ax.set_xticks(np.arange(0, pred_mask.shape[1] + 1, xtick_int))
        ax.set_yticks(np.arange(0, pred_mask.shape[0] + 1, ytick_int))
        ax.set_xlabel(r'image width [pixel]')
        ax.set_ylabel(r'image height [pixel]')
        ax.set_title(grid_title[i])

    # Save BEFORE showing: with a blocking backend the figure is closed once
    # plt.show() returns, so saving afterwards writes an empty image.
    if save_imag:
        fig.savefig(imag_name + '.' + save_as, bbox_inches='tight')
    if show_plt:
        plt.show()
    elif save_imag:
        # Clear memory (or matplotlib history) although the figure
        # is not shown
        plt.close(fig)
Beispiel #16
0
def make_gallery(files=None,
                 ims=None,
                 heads=None,
                 outname=None,
                 pfovs=None,
                 scale="lin",
                 norm=None,
                 absmin=None,
                 absmax=None,
                 permax=None,
                 permin=None,
                 papercol=2,
                 ncols=4,
                 pwidth=None,
                 nrows=None,
                 view_as=None,
                 view_px=None,
                 inv=True,
                 cmap='gist_heat',
                 colorbar=None,
                 cbarwidth=0.03,
                 cbarlabel=None,
                 cbarlabelpad=None,
                 xlabel=None,
                 ylabel=None,
                 xlabelpad=None,
                 ylabelpad=None,
                 titles=None,
                 titcols='black',
                 titx=0.5,
                 tity=0.95,
                 titvalign='top',
                 tithalign='center',
                 subtitles=None,
                 subtitcols='black',
                 subtitx=0.5,
                 subtity=0.05,
                 subtitvalign='bottom',
                 subtithalign='center',
                 plotsym=None,
                 contours=None,
                 alphamap=None,
                 cbartickinterval=None,
                 sbarlength=None,
                 sbarunit='as',
                 sbarcol='black',
                 sbarthick='2',
                 sbarpos=[0.1, 0.05],
                 majtickinterval=None,
                 verbose=False,
                 latex=True,
                 texts=None,
                 textcols='black',
                 textx=0.05,
                 texty=0.95,
                 textvalign='top',
                 texthalign='left',
                 textsize=None,
                 replace_NaNs=None,
                 smooth=None,
                 lines=None,
                 latexfontsize=16,
                 axes_pad=0.05):
    """
    MISSING:
        - implementation to rotate images to North up (and East to the left
          if desired and necessary

    The purpose of this routine is to create puplication-quality multiplots of
    images with a maximum number of customisability. The following parameters
    can be set either for all images a single variable or as an array with the
    same number of elements to set the parameters individually
    INPUT:
        - flles : list of fits files to be plotted
        - pfovs : pixel size in arcsec for the images. If not provided then it
                  will be looked for in the fits headers
        - log : enable logarithmic scaling
        - norm: provide the index of the image that should be normalised to
        - permax: set a percentile value which should be used as the maximum
                  in the colormap
        - papercol: either 1 for a plot fitting into one column in the paper or
                    2 in case the plot is supposed to go over the full page
        - ncols: number of columns of images in the multiplot
        - nrows: number of rows of images in the multiplot
        - view_as: size of the field of view that should be plotted in arcsec
        - inv: if set true then the used colormap is inverted
        - cmap: colormap to be used
        - xlabelpad, xlabelpad : adjust the position of the x,y axis label
        - titles: provide titles for the individual images
        - titcols: colors for the titles
        - titx, tity: x,y position of the title
        - titvalign, tithalign: vertical and horizontal alignment of the title
                                text with respect to the given position
        - subtitles: similar to titles but for another text in the image


    """

    # --- read in the images in a 2D list
    if ims is None:

        if type(files) == list:
            n = len(files)

            if nrows is None and ncols is not None:
                nrows = int(np.ceil(1.0 * n / ncols))

            if ncols is None and nrows is not None:
                ncols = int(np.ceil(1.0 * n / nrows))

            # --- reshape file list into 2D array
            rest = n % nrows
            for i in range((nrows - rest) % nrows):
                files.append(None)

            files = np.reshape(files, (nrows, ncols))

        elif type(files) == np.ndarray:

            s = np.shape(files)
            if len(s) > 1:
                if nrows is None:
                    nrows = s[0]
                if ncols is None:
                    ncols = s[1]
                n = nrows * ncols

        else:
            n = 1
            nrows = 1
            ncols = 1

            files = np.reshape(files, (-1, ncols))

        # --- create a fill a 2D image array
#        ims = [[None]*ncols]*nrows
#        ims = [None]*nrows*ncols
        ims = [[None] * ncols for _ in range(nrows)]
        #print(np.shape(ims))

        for r in range(nrows):
            for c in range(ncols):
                if files[r][c] != None:
                    if files[r][c] != "":
                        i = ncols * r + c
                        ims[r][c] = fits.getdata(files[r][c], ext=0)

    # images are provides
    else:

        s = np.shape(ims)

        if type(ims) == list or (type(ims) == np.ndarray and len(s) < 4):
            n = len(ims)

            if nrows is None and ncols is not None:
                nrows = int(np.ceil(1.0 * n / ncols))

            if ncols is None and nrows is not None:
                ncols = int(np.ceil(1.0 * n / nrows))

            # --- reshape file list into 2D array
            rest = n % nrows
            for i in range((nrows - rest) % nrows):
                ims.append(None)

            ims_old = np.copy(ims)
            ims = [[None] * ncols for _ in range(nrows)]

            for r in range(nrows):
                for c in range(ncols):
                    i = ncols * r + c
                    ims[r][c] = ims_old[i]

        elif type(ims) == np.ndarray:

            if len(s) > 1:
                if nrows is None:
                    nrows = s[0]
                if ncols is None:
                    ncols = s[1]
                n = nrows * ncols

        else:
            n = 1
            nrows = 1
            ncols = 1

            ims = [[ims]]

    if heads is None and files is not None:
        heads = [[None] * ncols for _ in range(nrows)]
        for r in range(nrows):
            for c in range(ncols):
                if files[r][c] != None:
                    if files[r][c] != "":
                        heads[r][c] = fits.getheader(files[r][c], ext=0)
#    print(np.shape(heads), len(np.shape(heads)))
    if heads is not None:
        if len(np.shape(heads)) == 1:
            heads = np.full((nrows, ncols), heads, dtype=object)

    if pfovs is None and heads is not None:
        pfovs = [[None] * ncols for _ in range(nrows)]
        for r in range(nrows):
            for c in range(ncols):
                if heads[r][c] != None:
                    if heads[r][c] != "":
                        if "PFOV" in heads[r][c]:
                            pfovs[r][c] = float(heads[r][c]['PFOV'])
                        elif "HIERARCH ESO INS PFOV" in heads[r][c]:
                            pfovs[r][c] = float(
                                heads[r][c]["HIERARCH ESO INS PFOV"])
                        elif "CDELT1" in heads[r][c]:
                            pfovs[r][c] = np.abs(heads[r][c]["CDELT1"]) * 3600
        if len(pfovs) == 0:
            pfovs = None

    else:
        pfovs = reshape_input_param(pfovs, nrows, ncols)

    scale = reshape_input_param(scale, nrows, ncols)
    permin = reshape_input_param(permin, nrows, ncols)
    permax = reshape_input_param(permax, nrows, ncols)
    absmin = reshape_input_param(absmin, nrows, ncols)
    absmax = reshape_input_param(absmax, nrows, ncols)
    norm = reshape_input_param(norm, nrows, ncols)
    smooth = reshape_input_param(smooth, nrows, ncols)
    titles = reshape_input_param(titles, nrows, ncols)
    titcols = reshape_input_param(titcols, nrows, ncols)
    texts = reshape_input_param(texts, nrows, ncols)
    textcols = reshape_input_param(textcols, nrows, ncols)
    subtitles = reshape_input_param(subtitles, nrows, ncols)
    subtitcols = reshape_input_param(subtitcols, nrows, ncols)
    sbarlength = reshape_input_param(sbarlength, nrows, ncols)

    if lines is not None:
        lines = reshape_input_param(lines, nrows, ncols)

    if verbose:

        #        ny, nx = ims[0].shape
        print("MAKE_GALLERY: nrows, ncols: ", nrows, ncols)
        print("MAKE_GALLERY: n: ", n)
        print("MAKE_GALLERY: pfovs: ", pfovs)
        print("MAKE_GALLERY: scales: ", scale)
#       print("MAKE_GALLERY: nx, ny: ", nx, ny)

# --- set up the plotting configuration to use latex

    if latex:
        mpl.rcdefaults()

        mpl.rc(
            'font', **{
                'family': 'sans-serif',
                'serif': ['Computer Modern Serif'],
                'sans-serif': ['Helvetica'],
                'size': latexfontsize,
                'weight': 500,
                'variant': 'normal'
            })

        mpl.rc('axes', **{'labelweight': 'normal', 'linewidth': 1.5})
        mpl.rc('ytick', **{'major.pad': 8, 'color': 'k'})
        mpl.rc('xtick', **{'major.pad': 8})
        mpl.rc(
            'mathtext', **{
                'default': 'regular',
                'fontset': 'cm',
                'bf': 'monospace:bold'
            })

        mpl.rc('text', **{'usetex': True})
        mpl.rc('text.latex',preamble=r'\usepackage{cmbright},\usepackage{relsize},'+\
                                    r'\usepackage{upgreek}, \usepackage{amsmath}'+\
                                    r'\usepackage{bm}')

        mpl.rc('contour', **{'negative_linestyle': 'solid'})  # dashed | solid

    plt.clf()

    # 14.17
    # 6.93
    if pwidth == None:
        if papercol == 1:
            pwidth = 6.93
        elif papercol == 2:
            pwidth = 14.17
        else:
            pwidth = 7 * papercol

    subpwidth = pwidth / ncols

    if verbose:
        print("MAKE_GALLERY: pwidth, subpwidth: ", pwidth, subpwidth)

    fig = plt.figure(figsize=(pwidth, nrows * subpwidth))
    # fig.subplots_adjust(bottom=0.2)
    # fig.subplots_adjust(left=0.2)

    if inv:
        cmap = cmap + '_r'

    grid = ImageGrid(fig,
                     111,
                     nrows_ncols=(nrows, ncols),
                     axes_pad=axes_pad,
                     aspect=True)

    handles = []
    ahandles = []
    vmin = np.zeros((nrows, ncols))
    vmax = np.zeros((nrows, ncols))
    amin = np.zeros((nrows, ncols))
    amax = np.zeros((nrows, ncols))

    # --- main plotting loop
    for r in range(nrows):
        for c in range(ncols):

            i = ncols * r + c

            if verbose:
                print("MAKE_GALLERY: r,c,i", r, c, i)
                print(" - smooth: ", smooth[r][c])

            if np.shape(ims[r][c]) == ():
                continue

            if view_as is not None:
                view_px = view_as / pfovs[r][c]
                im = _crop_image(ims[r][c], box=np.ceil(view_px / 2.0) * 2 + 2)
            else:
                im = np.copy(ims[r][c])

            #print(r,c, np.nanmax(ims[r][c]), np.nanmax(im))

            if smooth[r][c]:
                im = gaussian_filter(im, sigma=smooth[r][c], mode='nearest')

            if norm[r][c]:
                im[0, 0] = np.nanmax(ims[int(norm[r][c])])

            # --- adjust the cut levels
            if permin[r][c]:
                vmin[r][c] = np.nanpercentile(im, float(permin[r][c]))
            else:
                vmin[r][c] = np.nanmin(im)

            if permax[r][c]:
                vmax[r][c] = np.nanpercentile(im, float(permax[r][c]))
            else:
                vmax[r][c] = np.nanmax(im)

            if absmax[r][c] is not None:
                vmax[r][c] = absmax[r][c]

            if absmin[r][c] is not None:
                vmin[r][c] = absmin[r][c]

            sim = np.copy(im)
            sim[im < vmin[r][c]] = vmin[r][c]
            sim[im > vmax[r][c]] = vmax[r][c]

            if replace_NaNs is not None:
                idn = np.nonzero(np.isnan(sim))
                #            print(i, len(idn), idn[0])
                if len(idn) == 0:
                    continue
                elif replace_NaNs == "min":
                    sim[np.ix_(idn[0], idn[1])] = vmin[r][c]
                elif replace_NaNs == "max":
                    sim[np.ix_(idn[0], idn[1])] = vmax[r][c]
                else:
                    sim[np.ix_(idn[0], idn[1])] = replace_NaNs

            # --- logarithmic scaling?
            if scale[r][c] == "log":
                sim = np.log10(1000.0 * (sim - vmin[r][c]) /
                               (vmax[r][c] - vmin[r][c]) + 1)

            if verbose:
                print("MAKE_GALLERY: scale[r][c]", scale[r][c])
                print("MAKE_GALLERY: vmin[r][c], vmax[r][c]", vmin[r][c],
                      vmax[r][c])

            handle = grid[i].imshow(sim,
                                    cmap=cmap,
                                    origin='lower',
                                    interpolation='nearest')

            ny, nx = im.shape

            if verbose:
                print("ny, nx:", ny, nx)

            if pfovs is not None:

                xmin = (nx / 2 - 1) * pfovs[r][c]
                xmax = -nx / 2 * pfovs[r][c]
                ymin = -ny / 2 * pfovs[r][c]
                ymax = (ny / 2 - 1) * pfovs[r][c]
                if xlabel is None:
                    xlabel = 'RA offset ["]'
                if ylabel is None:
                    ylabel = 'DEC offset ["]'

            else:

                xmax = (nx / 2 - 1)
                xmin = -nx / 2
                ymin = -ny / 2
                ymax = (ny / 2 - 1)
                if xlabel is None:
                    xlabel = 'x offset [px]'
                if ylabel is None:
                    ylabel = 'y offset [px]'

            # print(pfovs[r][c])
            # pdb.set_trace()
            # set the extent of the image
            extent = [xmin, xmax, ymin, ymax]

            if verbose:
                print("extent:", extent)
                print("x/ylabel:", xlabel, ylabel)

            handle.set_extent(extent)

            # --- optionally overplot a second map using transparency
            if alphamap is not None:

                uim = np.copy(alphamap[r][c]['im'])

                # --- determine the levels
                if alphamap[r][c]['min'] is None:
                    amin[r][c] = np.nanmin(im)
                elif str(alphamap[r][c]['min']) == 'vmax':
                    amin[r][c] = vmax[r][c]
                elif str(alphamap[r][c]['min']) == 'vmin':
                    amin[r][c] = vmin[r][c]
                else:
                    amin[r][c] = np.nanpercentile(im,
                                                  float(alphamap[r][c]['min']))

                if alphamap[r][c]['max'] is None:
                    amax[r][c] = np.nanmax(im)
                elif str(alphamap[r][c]['max']) == 'vmax':
                    amax[r][c] = vmax[r][c]
                elif str(alphamap[r][c]['max']) == 'vmax':
                    amax[r][c] = vmin[r][c]
                else:
                    amax[r][c] = np.nanpercentile(im,
                                                  float(alphamap[r][c]['max']))

                uim[uim < amin[r][c]] = amin[r][c]
                uim[uim > amax[r][c]] = amax[r][c]

                cmap2 = _create_alpha_colmap(alphamap[r][c]['mincolor'],
                                             alphamap[r][c]['maxcolor'],
                                             alphamap[r][c]['minalpha'],
                                             alphamap[r][c]['maxalpha'])

                if alphamap[r][c]['log']:
                    unorm = LogNorm()
    #                uim = np.log10(1000.0 * (uim - amin[r][c]) /
    #                          (amax[r][c] - amin[r][c]) + 1)

                else:
                    unorm = None

                ahandle = grid[i].imshow(uim,
                                         cmap=cmap2,
                                         origin='lower',
                                         interpolation='nearest',
                                         norm=unorm)

                ahandle.set_extent(extent)
                ahandles.append(ahandle)

            # --- optionally draw contours
            if contours is not None:

                # --- determine the levels
                if contours[r][c]['min'] is None:
                    cmin = np.nanmin(im)
                elif str(contours[r][c]['min']) == 'vmax':
                    cmin = vmax[r][c]
                elif str(contours[r][c]['min']) == 'vmin':
                    cmin = vmin[r][c]
                else:
                    cmin = contours[r][c]['min']

                if contours[r][c]['max'] is None:
                    cmax = np.nanmax(im)
                elif str(contours[r][c]['max']) == 'vmax':
                    cmax = vmax[r][c]
                elif str(contours[r][c]['max']) == 'vmax':
                    cmax = vmin[r][c]
                else:
                    cmax = contours[r][c]['max']

                if contours[r][c]['nstep'] is None:
                    nstep = 10
                else:
                    nstep = contours[r][c]['nstep']

                if contours[r][c]['stepsize'] is None:
                    stepsize = int(cmax - cmin / nstep)
                else:
                    stepsize = contours[r][c]['stepsize']

                levels = np.arange(cmin, cmax, stepsize)

                cont = grid[i].contour(im,
                                       levels=levels,
                                       origin='lower',
                                       colors=contours[r][c]['color'],
                                       linewidth=contours[r][c]['linewidth'],
                                       extent=extent)

                grid[i].clabel(cont,
                               levels[1::contours[r][c]['labelinterval']],
                               inline=1,
                               fontsize=contours[r][c]['labelsize'],
                               fmt=contours[r][c]['labelfmt'])

            if pfovs is not None and view_as is not None:
                xmin = 0.5 * view_as
                xmax = -0.5 * view_as
                ymin = -0.5 * view_as
                ymax = 0.5 * view_as

            elif view_px is not None:
                xmin = -0.5 * view_px
                xmax = 0.5 * view_px
                ymin = -0.5 * view_px
                ymax = 0.5 * view_px

            if verbose:
                print("xmin,xmax,ymin,ymax:", xmin, xmax, ymin, ymax)

            grid[i].set_ylim(ymin, ymax)
            grid[i].set_xlim(xmin, xmax)

            # --- optionally draw some lines
            if lines is not None:

                nlin = len(lines[r][c]["length"])
                for l in range(nlin):

                    # --- provided unit for lines is arcsec?
                    #                print(lines[r][c]["length"][l])
                    if lines[r][c]["unit"][l] == "px":
                        lines[r][c]["length"][l] *= pfovs[r][c]
                        lines[r][c]["xoff"][l] *= pfovs[r][c]
                        lines[r][c]["yoff"][l] *= pfovs[r][c]

    #                print(lines[r][c]["length"][l])

                    ang_rad = (90 - lines[r][c]["pa"][l]) * np.pi / 180.0
                    hlength = 0.5 * lines[r][c]["length"][l]

                    lx = [
                        lines[r][c]["xoff"][l] - hlength * np.cos(ang_rad),
                        lines[r][c]["xoff"][l] + hlength * np.cos(ang_rad)
                    ]

                    ly = [
                        lines[r][c]["yoff"][l] - hlength * np.sin(ang_rad),
                        lines[r][c]["yoff"][l] + hlength * np.sin(ang_rad)
                    ]

                    grid[i].plot(lx,
                                 ly,
                                 color=lines[r][c]["color"][l],
                                 linestyle=lines[r][c]["style"][l],
                                 marker='',
                                 linewidth=lines[r][c]["thick"][l],
                                 alpha=lines[r][c]["alpha"][l])

                    #                theta1 = lines[r][c]["pa"][l] - 10
                    #                theta2 = lines[r][c]["pa"][l] + 10

                    # --- angular error bars
                    if lines[r][c]["pa_err"][l] > 0:

                        theta1 = 90 - lines[r][c]["pa"][l] - lines[r][c][
                            "pa_err"][l]
                        theta2 = 90 - lines[r][c]["pa"][l] + lines[r][c][
                            "pa_err"][l]

                        parc = Arc(
                            (lines[r][c]["xoff"][l], lines[r][c]["yoff"][l]),
                            width=2 * hlength,
                            height=2 * hlength,
                            color=lines[r][c]["color"][l],
                            linewidth=lines[r][c]["thick"][l],
                            angle=0,
                            theta1=theta1,
                            theta2=theta2,
                            alpha=lines[r][c]["alpha"][l])

                        grid[i].add_patch(parc)

                        theta1 = -90 - lines[r][c]["pa"][l] - lines[r][c][
                            "pa_err"][l]
                        theta2 = -90 - lines[r][c]["pa"][l] + lines[r][c][
                            "pa_err"][l]

                        parc = Arc(
                            (lines[r][c]["xoff"][l], lines[r][c]["yoff"][l]),
                            width=2 * hlength,
                            height=2 * hlength,
                            color=lines[r][c]["color"][l],
                            linewidth=lines[r][c]["thick"][l],
                            angle=0,
                            theta1=theta1,
                            theta2=theta2,
                            alpha=lines[r][c]["alpha"][l])

                        grid[i].add_patch(parc)

            # --- optionally draw some symbols
            if plotsym is not None:
                #             if hasattr(plotsym[r][c]['x'], "__len__"):
                nsymplots = len(plotsym[r][c]['x'])
                print('nsymplots: ', nsymplots)
                for j in range(nsymplots):
                    #                grid[i].plot(plotsym[r][c]['x'][j], plotsym[r][c]['y'][j],
                    #                             marker=plotsym[r][c]['marker'][j],
                    #                             markersize=plotsym[r][c]['markersize'][j],
                    #                             markeredgewidth=plotsym[r][c]['markeredgewidth'][j],
                    #                             fillstyle=plotsym[r][c]['fillstyle'][j],
                    #                             alpha=plotsym[r][c]['alpha'][j],
                    #                             markerfacecolor=plotsym[r][c]['markerfacecolor'][j],
                    #                             markeredgecolor=plotsym[r][c]['markeredgecolor'][j])

                    if plotsym[r][c]['marker'][j] == 'ch':
                        marker = _crosshair_marker()
                    elif plotsym[r][c]['marker'][j] == 'ch45':
                        marker = _crosshair_marker(pa=45)
                    else:
                        marker = plotsym[r][c]['marker'][j]

                    grid[i].scatter(plotsym[r][c]['x'][j],
                                    plotsym[r][c]['y'][j],
                                    marker=marker,
                                    s=plotsym[r][c]['size'][j],
                                    linewidths=plotsym[r][c]['linewidth'][j],
                                    alpha=plotsym[r][c]['alpha'][j],
                                    facecolor=plotsym[r][c]['color'][j],
                                    edgecolors=plotsym[r][c]['edgecolor'][j],
                                    linestyle=plotsym[r][c]['linestyle'][j])

                    if plotsym[r][c]['label'][j] is not None:
                        grid[i].text(
                            plotsym[r][c]['x'][j][0] +
                            plotsym[r][c]['labelxoffset'][j],
                            plotsym[r][c]['y'][j][0] +
                            plotsym[r][c]['labelyoffset'][j],
                            plotsym[r][c]['label'][j],
                            color=plotsym[r][c]['labelcolor'][j],
                            verticalalignment=plotsym[r][c]['labelvalign'][j],
                            horizontalalignment=plotsym[r][c]['labelhalign']
                            [j])

            # --- ticks
            if majtickinterval is None:

                majorLocator = MultipleLocator(1)

                if ymax - ymin < 2:
                    majorLocator = MultipleLocator(0.5)
                    minorLocator = AutoMinorLocator(5)

                elif (ymax - ymin > 10) & (ymax - ymin <= 20):
                    majorLocator = MultipleLocator(2)

                elif (ymax - ymin > 20) & (ymax - ymin <= 100):
                    majorLocator = MultipleLocator(5)

                elif ymax - ymin > 100:
                    majorLocator = MultipleLocator(10)

            else:
                majorLocator = MultipleLocator(majtickinterval)

            minorLocator = AutoMinorLocator(10)
            if ymax - ymin < 2:
                minorLocator = AutoMinorLocator(5)

            grid[i].xaxis.set_minor_locator(minorLocator)
            grid[i].xaxis.set_major_locator(majorLocator)
            grid[i].yaxis.set_minor_locator(minorLocator)
            grid[i].yaxis.set_major_locator(majorLocator)
            grid[i].yaxis.set_tick_params(width=1.5, which='both')
            grid[i].xaxis.set_tick_params(width=1.5, which='both')
            grid[i].xaxis.set_tick_params(length=6)
            grid[i].yaxis.set_tick_params(length=6)
            grid[i].xaxis.set_tick_params(length=3, which='minor')
            grid[i].yaxis.set_tick_params(length=3, which='minor')

            # --- text
            if titles[r][c]:
                #   pdb.set_trace()

                if verbose:
                    print("titles[r][c]", titles[r][c])

                grid[i].set_title(titles[r][c],
                                  x=titx,
                                  y=tity,
                                  color=titcols[r][c],
                                  verticalalignment=titvalign,
                                  horizontalalignment=tithalign)

            if subtitles[r][c]:

                if verbose:
                    print("subtitles[r][c]", subtitles[r][c])

                grid[i].text(subtitx,
                             subtity,
                             subtitles[r][c],
                             color=subtitcols[r][c],
                             transform=grid[i].transAxes,
                             verticalalignment=subtitvalign,
                             horizontalalignment=subtithalign)

            if texts[r][c]:
                if verbose:
                    print("texts[r][c]", texts[r][c])
                grid[i].text(textx,
                             texty,
                             texts[r][c],
                             color=textcols[r][c],
                             transform=grid[i].transAxes,
                             verticalalignment=textvalign,
                             horizontalalignment=texthalign,
                             fontsize=textsize)

            # --- scale bar for size comparison
            if sbarlength[r][c]:
                if sbarunit == 'px':
                    sbarlength[r][c] = sbarlength[r][c] * pfovs[r][c]
                sx = [
                    sbarpos[0] - 0.5 * sbarlength[r][c],
                    sbarpos[0] + 0.5 * sbarlength[r][c]
                ]
                sy = [sbarpos[1], sbarpos[1]]
                grid[i].plot(sx,
                             sy,
                             linewidth=sbarthick,
                             color=sbarcol,
                             transform=grid[i].transAxes)

            handles.append(handle)

    # --- get the extent of the largest box containing all the axes/subplots
    extents = np.array([bla.get_position().extents for bla in grid])
    bigextents = np.empty(4)
    bigextents[:2] = extents[:, :2].min(axis=0)
    bigextents[2:] = extents[:, 2:].max(axis=0)

    # --- distance between the external axis and the text
    if xlabelpad is None and papercol == 2:
        xlabelpad = 0.03
    elif xlabelpad is None and papercol == 1:
        xlabelpad = 0.15
    elif xlabelpad is None:
        xlabelpad = 0.15 / papercol * 0.5

    if ylabelpad is None and papercol == 2:
        ylabelpad = 0.055
    elif ylabelpad is None and papercol == 1:
        ylabelpad = 0.1
    elif ylabelpad is None:
        ylabelpad = 0.1 / papercol

    if verbose:
        print("xlabelpad,ylabelpad: ", xlabelpad, ylabelpad)

    # --- text to mimic the x and y label. The text is positioned in
    #     the middle
    fig.text((bigextents[2] + bigextents[0]) / 2,
             bigextents[1] - xlabelpad,
             xlabel,
             horizontalalignment='center',
             verticalalignment='bottom')

    fig.text(bigextents[0] - ylabelpad, (bigextents[3] + bigextents[1]) / 2,
             ylabel,
             rotation='vertical',
             horizontalalignment='left',
             verticalalignment='center')

    # --- now colorbar business:
    if colorbar is not None:
        # first draw the figure, such that the axes are positionned
        fig.canvas.draw()
        #create new axes according to coordinates of image plot
        trans = fig.transFigure.inverted()

    if colorbar == 'row':
        for i in range(nrows):
            g = grid[i * ncols - 1].bbox.transformed(trans)

            if alphamap[ncols, i] is not None:

                height = 0.5 * g.height
                pos = [
                    g.x1 + g.width * 0.02, g.y0 + height, g.width * cbarwidth,
                    height
                ]
                cax = fig.add_axes(pos)

                if alphamap[ncols, i]['log']:
                    formatter = LogFormatter(10, labelOnlyBase=False)
                    cb = plt.colorbar(ticks=[
                        0.1, 0.2, 0.5, 1, 5, 10, 20, 50, 100, 200, 500, 1000,
                        2000, 5000, 10000
                    ],
                                      format=formatter,
                                      mappable=ahandles[i * ncols - 1],
                                      cax=cax)
                else:
                    cb = plt.colorbar(mappable=ahandles[i * ncols - 1],
                                      cax=cax)

                #majorLocator = MultipleLocator(alphamap[i*ncols - 1]['cbartickinterval'])
                #cb.ax.yaxis.set_major_locator(majorLocator)

#                if alphamap[i*ncols - 1]['log']:
#                    oldlabels = cb.ax.get_yticklabels()
#                    oldlabels = np.array([float(x.get_text().replace('$','')) for x in oldlabels])
#                    newlabels = 0.001*(10.0**oldlabels -1)*(amax[i*ncols - 1] - amin[i*ncols - 1]) + amin[i*ncols - 1]
#                    newlabels = [str(x)[:4] for x in newlabels]
#                    cb.ax.set_yticklabels(newlabels)

            else:

                height = g.height

            # --- pos = [left, bottom, width, height]
            pos = [g.x1 + g.width * 0.02, g.y0, g.width * cbarwidth, height]
            cax = fig.add_axes(pos)

            if (vmax[ncols, i] - vmin[ncols, i]) < 10:
                decimals = 1
            else:
                decimals = 0
            if (vmax[ncols, i] - vmin[ncols, i]) < 1:
                decimals = 2

            ticks = np.arange(vmin[ncols, i] + 1.0 / 10**decimals,
                              vmax[ncols, i], cbartickinterval)

            ticks = np.round(ticks, decimals=decimals)

            cb = plt.colorbar(ticks=ticks,
                              mappable=handles[i * ncols - 1],
                              cax=cax)

            # majorLocator = MultipleLocator(cbartickinterval)
            # cb.ax.yaxis.set_major_locator(majorLocator)

            if scale[ncols, i] == "log":
                oldlabels = cb.ax.get_yticklabels()
                oldlabels = np.array(
                    [float(x.get_text().replace('$', '')) for x in oldlabels])
                newlabels = 0.001 * (10.0**oldlabels - 1) * (
                    vmax[i * ncols - 1] - vmin[ncols, i]) + vmin[ncols, i]
                newlabels = [str(x)[:4] for x in newlabels]
                cb.ax.set_yticklabels(newlabels)

    elif colorbar == 'single':

        gb = grid[-1].bbox.transformed(trans)
        gt = grid[0].bbox.transformed(trans)

        if alphamap is not None:
            height = 0.5 * (gt.y1 - gb.y0)
            pos = [
                gb.x1 + gb.width * 0.02, gb.y0 + height, gb.width * cbarwidth,
                height
            ]
            cax = fig.add_axes(pos)
            # cb = plt.colorbar(mappable=ahandles[0], cax=cax)

            #               majorLocator = MultipleLocator(alphamap[0]['cbartickinterval'])
            #               cb.ax.yaxis.set_major_locator(majorLocator)

            #                if alphamap[0]['log']:
            #                   oldlabels = cb.ax.get_yticklabels()
            #                   oldlabels = np.array([float(x.get_text().replace('$','')) for x in oldlabels])
            #                   newlabels = 0.001*(10.0**oldlabels -1)*(amax[0] - amin[0]) + amin[0]
            #                   newlabels = [str(x)[:4] for x in newlabels]
            #                   cb.ax.set_yticklabels(newlabels)
            if alphamap[0]['log']:
                formatter = LogFormatter(10, labelOnlyBase=False)
                cb = plt.colorbar(ticks=[
                    0.1, 0.2, 0.5, 1, 5, 10, 20, 50, 100, 200, 500, 1000, 2000,
                    5000, 10000
                ],
                                  format=formatter,
                                  mappable=ahandles[0],
                                  cax=cax)
            else:
                cb = plt.colorbar(mappable=ahandles[0], cax=cax)

        else:
            height = (gt.y1 - gb.y0)

        # --- pos = [left, bottom, width, height]
        pos = [gb.x1 + gb.width * 0.02, gb.y0, gb.width * cbarwidth, height]
        cax = fig.add_axes(pos)
        cb = plt.colorbar(mappable=handles[0], cax=cax)

        # majorLocator = MultipleLocator(cbartickinterval)
        # cb.ax.yaxis.set_major_locator(majorLocator)

        if scale[0, 0] == "log":
            oldlabels = cb.ax.get_yticklabels()
            oldlabels = np.array(
                [float(x.get_text().replace('$', '')) for x in oldlabels])
            newlabels = 0.001 * (10.0**oldlabels - 1) * (vmax[0] -
                                                         vmin[0]) + vmin[0]
            newlabels = [str(x)[:4] for x in newlabels]
            cb.ax.set_yticklabels(newlabels)

    if cbarlabel is not None:

        if cbarlabelpad is None:
            if papercol == 1:
                cbarlabelpad = 0.15
            else:
                cbarlabelpad = 0.08

        fig.text(bigextents[2] + cbarlabelpad,
                 (bigextents[3] + bigextents[1]) / 2,
                 cbarlabel,
                 rotation='vertical',
                 horizontalalignment='right',
                 verticalalignment='center')

    if outname:
        plt.savefig(outname, bbox_inches='tight', pad_inches=0.1)
        plt.close(fig)

    else:
        plt.show()

    if latex:
        mpl.rcdefaults()
Beispiel #17
0
def create_channelmap(raster=None,
                      contour=None,
                      clevels=None,
                      zeropoint=0.,
                      channels=None,
                      ncols=4,
                      vrange=None,
                      vscale=None,
                      show=True,
                      pdfname=None,
                      cbarlabel=None,
                      cmap='RdYlBu_r',
                      beam=True,
                      **kwargs):
    """ Running this function will create a quick channel map of the Qube.
    One can either plot the contours or the raster image or both. This program
    can be used as a basis for a more detailed individual plot which can take
    better care of whitespace, etc. The following keywords are valid:

    raster:     qube object used to generate the raster image, leave blank or
                'None' if not desired.
    contour:    qube object used to generate the contours, leave blank or
                'None' if not desired.
    clevels:    nested lists of contour levels to draw, list should be the same
                length as the spectral dimension of the contour qube, if a
                single list is given assumes that the contours are the same
                for all channels. Required when `contour` is given.
    zeropoint:  Optional shift in velocities compared to the Restfrq keyword in
                the header of the Qube.
    channels:   indices of the spectral channels to plot (default: all).
    ncols:      number of columns to plot
    vrange:     range of the z-axis (for consistency across panels) if None,
                will take minimum maximum of the Qube
    vscale:     spacing between colorbar ticks (if None will make ~5
                divisions of vrange). Ticks are placed on every multiple of
                vscale inside vrange.
    show:       if True (and pdfname is not set), call plt.show()
    pdfname:    if set, the name of the pdf file to which the image is saved
                (the figure is then not shown)
    cbarlabel:  if set, label of the color bar
    cmap:       colormap to use for the plotting
    beam:       if True, draw the beam in the bottom-left panel only
    **kwargs:   passed on to standardfig for every panel

    Raises:
        ValueError: if neither raster nor contour is given, or if contour is
            given without clevels.

    Returns:
        fig, grid, channels: the matplotlib figure, the ImageGrid holding the
        panels and the channel indices that were plotted.
    """

    # generate a temporary qube from the data
    if raster is not None:
        tqube = raster
    elif contour is not None:
        tqube = contour
    else:
        raise ValueError('Need to define either a contour or raster image.')

    # first generate the velocity array
    VelArr = tqube._getvelocity_() - zeropoint

    # define the number of channels, rows (and columns = already defined)
    if channels is None:
        channels = np.arange(tqube.shape[0])
    nrows = int(np.ceil(len(channels) / ncols))

    # define the range of the z-axis (if needed)
    if vrange is None:
        vrange = [np.nanmin(tqube.data), np.nanmax(tqube.data)]

    # now generate a grid with the specified number of columns
    fig = plt.figure(1, (8., 8. / ncols * nrows))
    grid = ImageGrid(fig,
                     111,
                     nrows_ncols=(nrows, ncols),
                     axes_pad=0.,
                     cbar_mode='single',
                     cbar_location='right',
                     share_all=True)

    # now loop over the channels
    for idx, chan in enumerate(channels):

        # get the string value of the velocity
        VelStr = str(int(round(VelArr[chan]))) + ' km s$^{-1}$'

        # draw the beam only in the bottom-left panel
        beambool = bool(beam and idx % ncols == 0
                        and idx // ncols == nrows - 1)

        # plot the individual channels
        if raster is not None:
            rasterimage = raster.get_slice(zindex=(chan, chan + 1))
        else:
            rasterimage = None
        if contour is not None:
            contourimage = contour.get_slice(zindex=(chan, chan + 1))
            # also get the contour levels (per channel if nested)
            if clevels is None:
                raise ValueError('Set the contour levels using the clevels ' +
                                 'keyword')
            elif isinstance(clevels[0], (list, np.ndarray)):
                clevel = clevels[chan]
            else:
                clevel = clevels
        else:
            contourimage = None
            clevel = clevels

        standardfig(raster=rasterimage,
                    contour=contourimage,
                    clevels=clevel,
                    newplot=False,
                    fig=fig,
                    ax=grid[idx],
                    cbar=False,
                    beam=beambool,
                    vrange=vrange,
                    text=[VelStr],
                    cmap=cmap,
                    **kwargs)

    # now do the color bar
    norm = mpl.colors.Normalize(vmin=vrange[0], vmax=vrange[1])
    cmapo = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    cmapo.set_array([])
    if vscale is None:
        vscale = (vrange[1] - vrange[0]) / 5.
    if vscale > 0:
        # Cover the whole of vrange with multiples of vscale (the previous
        # fixed range of -10..9 * vscale missed ticks whenever vrange lay
        # outside it). Ticks outside the colorbar limits are not drawn.
        ticks = np.arange(np.floor(vrange[0] / vscale),
                          np.ceil(vrange[1] / vscale) + 1) * vscale
    else:
        # degenerate range / spacing: let matplotlib choose the ticks
        ticks = None
    cbr = plt.colorbar(cmapo, cax=grid.cbar_axes[0], ticks=ticks)

    # label of color bar
    if cbarlabel is not None:
        cbr.ax.set_ylabel(cbarlabel, labelpad=-1)

    if pdfname is not None:
        plt.savefig(pdfname, format='pdf', dpi=300)
    elif show:
        plt.show()

    return fig, grid, channels
def attention_epoch_plot(net,
                         folder_name,
                         source_images,
                         logged=False,
                         width=5,
                         device=torch.device('cpu'),
                         layer_name_base='attention',
                         layer_no=2,
                         cmap_name='magma',
                         figsize=(100, 100)):
    """Plot a grid of attention maps as they develop over training epochs.

    Each row starts with one source image, followed by its attention maps at
    up to ``width`` epochs sampled evenly from the saved epochs.

    Args:
        net: network passed to ``AttentionImagesByEpoch``.
        folder_name: folder passed to ``AttentionImagesByEpoch``.
        source_images: array-like of source images; the first axis indexes
            the samples (one grid row per sample).
        logged: if True, source images are exponentiated when collected and
            every panel is displayed as ``np.log`` of its data.
        width: number of epoch columns to show. If it exceeds the number of
            saved epochs, all saved epochs are shown instead.
        device: torch device passed to ``AttentionImagesByEpoch``.
        layer_name_base, layer_no: passed to ``AttentionImagesByEpoch``.
        cmap_name: colormap name. The special value ``'RGB'`` requests
            non-averaged attention maps (``mean=False``) and displays them
            with ``'magma'``.
        figsize: matplotlib figure size in inches.

    Side effects:
        Prints the epoch labels and calls ``plt.show()``.
    """

    # 'RGB' is not a colormap: it asks for per-channel (non-averaged) maps,
    # which are still displayed with 'magma'. Any other colormap keeps the
    # averaged maps. Previously mean_ was left unbound in that case, so every
    # non-'RGB' call raised NameError.
    mean_ = True
    if cmap_name == 'RGB':
        mean_ = False
        cmap_name = 'magma'

    # Generate attention maps for each available epoch
    attention_maps_temp, _, epoch_labels = AttentionImagesByEpoch(
        source_images,
        folder_name,
        net,
        epoch=2000,
        device=device,
        layer_name_base=layer_name_base,
        layer_no=layer_no,
        mean=mean_)

    # The maps are indexed as attention_maps_temp[sample_number * epoch + sample].
    sample_number = source_images.shape[0]
    no_saved_attentions_epochs = np.asarray(
        attention_maps_temp).shape[0] // sample_number

    # Choose which saved epochs become columns (evenly spaced if fewer
    # columns than saved epochs were requested).
    width_array = range(no_saved_attentions_epochs)
    if width <= no_saved_attentions_epochs:
        width_array = np.linspace(0,
                                  no_saved_attentions_epochs - 1,
                                  num=width,
                                  dtype=np.int32)
    else:
        width = no_saved_attentions_epochs

    # Collect images row by row: source image, then its attention maps.
    imgs = []
    labels = []
    for j in range(sample_number):
        if logged:
            imgs.append(np.exp(source_images[j].squeeze()))
        else:
            imgs.append(source_images[j].squeeze())
        for i in width_array:
            imgs.append(attention_maps_temp[sample_number * i + j])
            # The epoch labels are shared by all rows; collect them once.
            if len(labels) < width:
                labels.append(epoch_labels[sample_number * i])

    # Define the plot of the grid of images
    fig = plt.figure(figsize=figsize)
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(sample_number, width + 1),
        axes_pad=0.02,  # pad between axes in inch.
    )
    for idx, (ax, im) in enumerate(zip(grid, imgs)):
        # Channel-first RGB image -> channel-last for imshow
        if im.shape[0] == 3:
            im = im.transpose(1, 2, 0)
        if logged:
            ax.imshow(np.log(im), cmap=cmap_name)
        else:
            ax.imshow(im, cmap=cmap_name)
        # The first column holds the source images: outline them with a contour
        if idx % (width + 1) == 0:
            ax.contour(im, 1, cmap='cool', alpha=0.5)
        ax.axis('off')
    print(
        f'Source images followed by their respective averaged attention maps at epochs:\n{labels}'
    )
    plt.show()
Beispiel #19
0
def modeldata_comparison(cube, pdffile=None, **kwargs):
    """This will make a six-panel plot for providing an easy overview of the
    model and the data. The top row shows the zeroth moment of the data, the
    model and the residual. The bottom row shows the first moment of the data
    and the model and the residual of the velocity field.

    inputs:
    -------
    cube: Qubefit object that holds the data and the model

    keywords:
    ---------
    pdffile (string|default: None): If set, the figure is saved to this pdf
        file.

    showmask (Bool|default: False): If set, it will show a contour of the mask
        used in creating the fit (stored in cube.maskarray). This keyword is
        not handled here; it is forwarded to standardfig through **kwargs,
        like any other extra keyword.

    returns:
    --------
    fig: the matplotlib figure holding the six panels.
    """

    # create the figure plots
    fig = plt.figure(1, (8., 8.))
    grid = ImageGrid(fig, 111, nrows_ncols=(2, 3), axes_pad=0)

    # create the data, model and residual cubes
    dqube = dc(cube)
    mqube = dc(cube)
    mqube.data = cube.model
    rqube = dc(cube)
    rqube.data = cube.data - cube.model

    # moment-zero data; contours at multiples of 3 sigma from -30 to +27 sigma
    # (zero skipped), color range from -3 to +11 sigma
    dMom0 = dqube.calculate_moment(moment=0)
    Mom0sig = (np.sqrt(np.nansum(cube.variance[:, 0, 0])) *
               cube.__get_velocitywidth__())
    clevels = np.insert(np.arange(3, 30, 3), 0, np.arange(-30, 0, 3)) * Mom0sig
    vrangemom = [-3 * Mom0sig, 11 * Mom0sig]
    # 3-sigma region of the data moment-zero map, reused to mask the moment-one maps
    mask = dMom0.mask_region(value=Mom0sig * 3, applymask=False)
    standardfig(raster=dMom0,
                contour=dMom0,
                clevels=clevels,
                ax=grid[0],
                fig=fig,
                vrange=vrangemom,
                cmap='RdYlBu_r',
                text='Data',
                textprop=[dict(size=12)],
                **kwargs)

    # moment-zero model
    mMom0 = mqube.calculate_moment(moment=0)
    standardfig(raster=mMom0,
                contour=mMom0,
                clevels=clevels,
                ax=grid[1],
                fig=fig,
                vrange=vrangemom,
                cmap='RdYlBu_r',
                text='Model',
                textprop=[dict(size=12)],
                **kwargs)

    # moment-zero residual
    rMom0 = rqube.calculate_moment(moment=0)
    standardfig(raster=rMom0,
                contour=rMom0,
                clevels=clevels,
                ax=grid[2],
                fig=fig,
                vrange=vrangemom,
                cmap='RdYlBu_r',
                text='Residual',
                textprop=[dict(size=12)],
                **kwargs)

    # moment-one data (masked); color range is 95% of its min/max
    dsqube = dqube.mask_region(mask=mask)
    dMom1 = dsqube.calculate_moment(moment=1)
    vrangemom1 = [0.95 * np.nanmin(dMom1.data), 0.95 * np.nanmax(dMom1.data)]
    standardfig(raster=dMom1,
                ax=grid[3],
                fig=fig,
                cmap='Spectral_r',
                vrange=vrangemom1,
                text='Data',
                textprop=[dict(size=12)],
                **kwargs)

    # moment-one model (same mask)
    msqube = mqube.mask_region(mask=mask)
    mMom1 = msqube.calculate_moment(moment=1)
    standardfig(raster=mMom1,
                ax=grid[4],
                fig=fig,
                cmap='Spectral_r',
                vrange=vrangemom1,
                beam=False,
                text='Model',
                textprop=[dict(size=12)],
                **kwargs)

    # moment-one residual (data minus model velocity field)
    rMom1 = dc(mMom1)
    rMom1.data = dMom1.data - mMom1.data
    standardfig(raster=rMom1,
                ax=grid[5],
                fig=fig,
                cmap='Spectral_r',
                vrange=vrangemom1,
                beam=False,
                text='Residual',
                textprop=[dict(size=12)],
                **kwargs)

    # honor the documented pdffile keyword (it was previously ignored)
    if pdffile is not None:
        fig.savefig(pdffile, format='pdf', dpi=300)

    return fig
Beispiel #20
0
# Report how many training images each category directory contains.
for category in CATEGORIES:
    print('{}: {} images'.format(category, len(os.listdir(os.path.join(train_dir, category)))))
print("-" * 27)

# Create a dataframe for the dataset: file path, category name and category id
# (the category's position in CATEGORIES). Paths are built from train_dir so
# they always point at the directories that were actually listed.
train = []
for category_id, category in enumerate(CATEGORIES):
    category_dir = os.path.join(train_dir, category)
    for file in os.listdir(category_dir):
        train.append([os.path.join(category_dir, file), category, category_id])

train = pd.DataFrame(train, columns=['file', 'category', 'category_id'])


# Show NumCatergories sample images per category, one category per grid row,
# with the category name written next to the last image of a full row.
fig = plt.figure(1, figsize=(NumCatergories, NumCatergories))
grid = ImageGrid(fig, 111, nrows_ncols=(NumCatergories, NumCatergories), axes_pad=0.05)

for row, category in enumerate(CATEGORIES):
    sample_files = train[train['category'] == category]['file'].values[:NumCatergories]
    for col, filepath in enumerate(sample_files):
        # Index by (row, col) so a category with fewer images cannot shift
        # later categories into the wrong row.
        ax = grid[row * NumCatergories + col]
        ax.imshow(read_img(filepath))
        ax.axis('off')
        if col == NumCatergories - 1:
            # Data coordinates; presumably chosen for ~224 px images so the
            # text sits just right of the image — TODO confirm image size.
            ax.text(250, 112, category, verticalalignment='center')

print("display samples")
plt.pause(3)
plt.savefig('../results/sample_images.png', transparent=True)

def get_demo_image():
    """Return matplotlib's bundled bivariate-normal sample array and an extent.

    The array is read from the sample-data file
    ``axes_grid/bivariate_normal.npy`` (a 15x15 array according to the
    original author). The second element is the fixed 4-tuple
    ``(-3, 4, -4, 3)`` describing the data extent.
    """
    from matplotlib.cbook import get_sample_data
    import numpy as np

    sample_path = get_sample_data("axes_grid/bivariate_normal.npy",
                                  asfileobj=False)
    demo_extent = (-3, 4, -4, 3)
    return np.load(sample_path), demo_extent


# Build a 1x3 ImageGrid in a 5.5 x 3.5 inch figure. Per the ImageGrid
# documentation, label_mode="L" puts tick labels only on the left column and
# bottom row.
# NOTE(review): the `add_all` keyword was deprecated and later removed from
# ImageGrid in newer Matplotlib releases — confirm the installed version still
# accepts it.
F = plt.figure(1, (5.5, 3.5))
grid = ImageGrid(
    F,
    111,  # similar to subplot(111)
    nrows_ncols=(1, 3),
    axes_pad=0.1,
    add_all=True,
    label_mode="L",
)

Z, extent = get_demo_image()  # demo image

# Three views of the same array: the full image, its first 10 columns and the
# remaining columns. vmin/vmax are taken from the full array, presumably so
# that every panel is drawn with one shared color scale.
im1 = Z
im2 = Z[:, :10]
im3 = Z[:, 10:]
vmin, vmax = Z.min(), Z.max()
for i, im in enumerate([im1, im2, im3]):
    ax = grid[i]
    ax.imshow(im,
              origin="lower",
              vmin=vmin,
Beispiel #22
0
    def _savegif(
        self,
        stems: List[str],
        imgs: NDArray,
        masks: NDArray,
        reconstructed_imgs: NDArray,
        amaps: NDArray,
    ) -> None:
        """Write one comparison PNG per sample, then join them into a GIF.

        For each sample a 1x4 panel figure is saved as ``results/<stem>.png``:
        input image, reconstructed image, input with the ground-truth mask
        overlaid, and input with the anomaly map overlaid (with a single
        colorbar spanning 0..1). ImageMagick's ``convert`` is then run through
        the shell to combine ``results/*.png`` into ``result.gif``.

        The ``results`` directory is created with ``os.mkdir``, so this raises
        if it already exists.
        """

        os.mkdir("results")

        # Hide both the ticks and their labels on every image panel.
        hidden_ticks = dict(labelbottom=False,
                            labelleft=False,
                            bottom=False,
                            left=False)

        samples = zip(stems, imgs, masks, reconstructed_imgs, amaps)
        for stem, img, mask, reconstructed_img, amap in tqdm(samples,
                                                             desc="savegif"):

            # How to get two subplots to share the same y-axis with a single colorbar
            # https://stackoverflow.com/a/38940369
            grid = ImageGrid(
                fig=plt.figure(figsize=(16, 4)),
                rect=111,
                nrows_ncols=(1, 4),
                axes_pad=0.15,
                share_all=True,
                cbar_location="right",
                cbar_mode="single",
                cbar_size="5%",
                cbar_pad=0.15,
            )
            input_ax, recon_ax, truth_ax, amap_ax = (grid[n] for n in range(4))

            input_ax.imshow(img, cmap="gray")
            input_ax.tick_params(**hidden_ticks)
            input_ax.set_title("Input Image", fontsize=20)

            recon_ax.imshow(reconstructed_img, cmap="gray")
            recon_ax.tick_params(**hidden_ticks)
            recon_ax.set_title("Reconstructed Image", fontsize=20)

            truth_ax.imshow(img, cmap="gray")
            truth_ax.imshow(mask, alpha=0.3, cmap="Reds")
            truth_ax.tick_params(**hidden_ticks)
            truth_ax.set_title("Ground Truth", fontsize=20)

            amap_ax.imshow(img, cmap="gray")
            heatmap = amap_ax.imshow(amap, alpha=0.3, cmap="jet", vmin=0, vmax=1)
            amap_ax.tick_params(**hidden_ticks)
            amap_ax.cax.toggle_label(True)
            amap_ax.set_title("Anomaly Map", fontsize=20)

            plt.colorbar(heatmap, cax=grid.cbar_axes[0])
            plt.savefig(f"results/{stem}.png", bbox_inches="tight")
            plt.close()

        # NOTE(inoue): The gif files converted by PIL or imageio were low-quality.
        #              So, I used the conversion command (ImageMagick) instead.
        subprocess.run("convert -delay 100 -loop 0 results/*.png result.gif",
                       shell=True)
Beispiel #23
0
    def show_image(self):
        """Redraw ``self.fig`` with one panel (plus colorbar) per activated item.

        Items come from ``self.get_activated_num()`` and are drawn sorted by
        key on a roughly square ImageGrid. Each panel is either an ``imshow``
        raster or, when scatter mode is active, a scatter plot over the
        ``'positions'`` x/y arrays. If ``self.scale_opt == 'Linear'``, the
        color range is clipped to the ``self.limit_dict`` percentages.
        Otherwise a logarithmic scale over ``[1e-4 * max, max]`` is used.

        Data are normalized with ``normalize_data_by_scaler`` using
        ``self.scaler_data``. If ``self.data_dict`` has no ``'positions'``
        entry, plotting vs. x-y coordinates, scatter plotting and grid
        interpolation are switched off for this call, and an error is logged
        for each requested mode. A scatter panel whose data shape no longer
        matches the positions arrays is left empty (no colorbar) instead of
        crashing.
        """
        self.fig.clf()
        stat_temp = self.get_activated_num()
        stat_temp = OrderedDict(
            sorted(six.iteritems(stat_temp), key=lambda x: x[0]))

        # Check if positions data is available. Positions data may be unavailable
        # (not recorded in HDF5 file) if experiment is has not been completed.
        # While the data from the completed part of experiment may still be used,
        # plotting vs. x-y or scatter plot may not be displayed.
        positions_data_available = 'positions' in self.data_dict

        # Create local copies of self.pixel_or_pos, self.scatter_show and self.grid_interpolate
        pixel_or_pos_local = self.pixel_or_pos
        scatter_show_local = self.scatter_show
        grid_interpolate_local = self.grid_interpolate

        # Disable plotting vs x-y coordinates if 'positions' data is not available
        if not positions_data_available:
            if pixel_or_pos_local:
                pixel_or_pos_local = 0  # Switch to plotting vs. pixel number
                logger.error(
                    "'Positions' data is not available. Plotting vs. x-y coordinates is disabled"
                )
            if scatter_show_local:
                scatter_show_local = False  # Switch to plotting vs. pixel number
                logger.error(
                    "'Positions' data is not available. Scatter plot is disabled."
                )
            if grid_interpolate_local:
                grid_interpolate_local = False  # Switch to plotting vs. pixel number
                logger.error(
                    "'Positions' data is not available. Interpolation is disabled."
                )

        low_lim = 1e-4  # define the low limit for log image
        plot_interp = 'Nearest'

        if self.scaler_data is not None:
            if np.count_nonzero(self.scaler_data) == 0:
                logger.warning('scaler is zero - scaling was not applied')
            elif len(self.scaler_data[self.scaler_data == 0]) > 0:
                logger.warning('scaler data has zero values')

        grey_use = self.color_opt

        ncol = int(np.ceil(np.sqrt(len(stat_temp))))
        try:
            nrow = int(np.ceil(len(stat_temp) / float(ncol)))
        except ZeroDivisionError:
            # No activated items (ncol == 0): still build a 1x1 grid
            ncol = 1
            nrow = 1

        # NOTE(review): per the ImageGrid docs, a tuple axes_pad is
        # (horizontal, vertical), so a_pad_v is used as the horizontal pad.
        a_pad_v = 0.8
        a_pad_h = 0.5

        grid = ImageGrid(self.fig,
                         111,
                         nrows_ncols=(nrow, ncol),
                         axes_pad=(a_pad_v, a_pad_h),
                         cbar_location='right',
                         cbar_mode='each',
                         cbar_size='7%',
                         cbar_pad='2%',
                         share_all=True)

        def _compute_equal_axes_ranges(x_min, x_max, y_min, y_max):
            """
            Compute ranges for x- and y- axes of the plot. Make sure that the ranges for x- and y-axes are
            always equal and fit the maximum of the ranges for x and y values:
                  max(abs(x_max-x_min), abs(y_max-y_min))
            The ranges are set so that the data is always centered in the middle of the ranges

            Parameters
            ----------

            x_min, x_max, y_min, y_max : float
                lower and upper boundaries of the x and y values

            Returns
            -------

            x_axis_min, x_axis_max, y_axis_min, y_axis_max : float
                lower and upper boundaries of the x- and y-axes ranges
            """

            x_axis_min, x_axis_max, y_axis_min, y_axis_max = x_min, x_max, y_min, y_max
            x_range, y_range = abs(x_max - x_min), abs(y_max - y_min)
            if x_range > y_range:
                y_center = (y_max + y_min) / 2
                y_axis_max = y_center + x_range / 2
                y_axis_min = y_center - x_range / 2
            else:
                x_center = (x_max + x_min) / 2
                x_axis_max = x_center + y_range / 2
                x_axis_min = x_center - y_range / 2

            return x_axis_min, x_axis_max, y_axis_min, y_axis_max

        def _adjust_data_range_using_min_ratio(c_min,
                                               c_max,
                                               c_axis_range,
                                               *,
                                               min_ratio=0.01):
            """
            Adjust the range for plotted data along one axis (x or y). The adjusted range is
            applied to the 'extent' attribute of imshow(). The adjusted range is always greater
            than 'axis_range * min_ratio'. Such transformation has no physical meaning
            and performed for aesthetic reasons: stretching the image presentation of
            a scan with only a few lines (1-3) greatly improves visibility of data.

            Parameters
            ----------

            c_min, c_max : float
                boundaries of the data range (along x or y axis)
            c_axis_range : float
                range presented along the same axis

            Returns
            -------

            c_min, c_max : float
                adjusted boundaries of the data range
            """
            c_range = c_max - c_min
            if c_range < c_axis_range * min_ratio:
                c_center = (c_max + c_min) / 2
                c_new_range = c_axis_range * min_ratio
                c_min = c_center - c_new_range / 2
                c_max = c_center + c_new_range / 2
            return c_min, c_max

        for i, k in enumerate(stat_temp):

            # Mappable drawn in this panel. It stays None when nothing is
            # drawn (scatter data shape does not match the positions); the
            # colorbar is then skipped instead of failing on an unbound or
            # stale handle from a previous panel.
            im = None

            data_dict = normalize_data_by_scaler(
                data_in=self.dict_to_plot[k],
                scaler=self.scaler_data,
                data_name=k,
                name_not_scalable=self.name_not_scalable)

            if pixel_or_pos_local or scatter_show_local:

                # xd_min, xd_max, yd_min, yd_max = min(self.x_pos), max(self.x_pos),
                #     min(self.y_pos), max(self.y_pos)
                x_pos_2D = self.data_dict['positions']['x_pos']
                y_pos_2D = self.data_dict['positions']['y_pos']
                xd_min, xd_max, yd_min, yd_max = x_pos_2D.min(), x_pos_2D.max(
                ), y_pos_2D.min(), y_pos_2D.max()
                xd_axis_min, xd_axis_max, yd_axis_min, yd_axis_max = \
                    _compute_equal_axes_ranges(xd_min, xd_max, yd_min, yd_max)

                xd_min, xd_max = _adjust_data_range_using_min_ratio(
                    xd_min, xd_max, xd_axis_max - xd_axis_min)
                yd_min, yd_max = _adjust_data_range_using_min_ratio(
                    yd_min, yd_max, yd_axis_max - yd_axis_min)

                # Adjust the direction of each axis depending on the direction in which encoder values changed
                #   during the experiment. Data is plotted starting from the upper-right corner of the plot
                if x_pos_2D[0, 0] > x_pos_2D[0, -1]:
                    xd_min, xd_max, xd_axis_min, xd_axis_max = xd_max, xd_min, xd_axis_max, xd_axis_min
                if y_pos_2D[0, 0] > y_pos_2D[-1, 0]:
                    yd_min, yd_max, yd_axis_min, yd_axis_max = yd_max, yd_min, yd_axis_max, yd_axis_min

            else:

                yd, xd = data_dict.shape

                # Plot vs. pixel number; very elongated maps (one side at
                # least 100x the other and >= 200 px) get a padded short side.
                xd_min, xd_max, yd_min, yd_max = 0, xd, 0, yd
                if (yd <= math.floor(xd / 100)) and (xd >= 200):
                    yd_min, yd_max = -math.floor(xd / 200), math.ceil(xd / 200)
                if (xd <= math.floor(yd / 100)) and (yd >= 200):
                    xd_min, xd_max = -math.floor(yd / 200), math.ceil(yd / 200)

                xd_axis_min, xd_axis_max, yd_axis_min, yd_axis_max = \
                    _compute_equal_axes_ranges(xd_min, xd_max, yd_min, yd_max)

            if self.scale_opt == 'Linear':

                low_ratio = self.limit_dict[k]['low'] / 100.0
                high_ratio = self.limit_dict[k]['high'] / 100.0
                if self.scaler_data is None:
                    minv = self.range_dict[k]['low']
                    maxv = self.range_dict[k]['high']
                else:
                    # Unfortunately, the new normalization procedure requires to recalculate min and max values
                    minv = np.min(data_dict)
                    maxv = np.max(data_dict)
                low_limit = (maxv - minv) * low_ratio + minv
                high_limit = (maxv - minv) * high_ratio + minv

                # Set some minimum range for the colorbar (otherwise it will have white fill)
                if math.isclose(low_limit, high_limit, abs_tol=2e-20):
                    if abs(low_limit) < 1e-20:  # The value is zero
                        dv = 1e-20
                    else:
                        dv = math.fabs(low_limit * 0.01)
                    high_limit += dv
                    low_limit -= dv

                if not scatter_show_local:
                    if grid_interpolate_local:
                        data_dict, _, _ = grid_interpolate(
                            data_dict, self.data_dict['positions']['x_pos'],
                            self.data_dict['positions']['y_pos'])
                    im = grid[i].imshow(data_dict,
                                        cmap=grey_use,
                                        interpolation=plot_interp,
                                        extent=(xd_min, xd_max, yd_max,
                                                yd_min),
                                        origin='upper',
                                        clim=(low_limit, high_limit))
                    grid[i].set_ylim(yd_axis_max, yd_axis_min)
                else:
                    xx = self.data_dict['positions']['x_pos']
                    yy = self.data_dict['positions']['y_pos']

                    # The following condition prevents crash if different file is loaded while
                    #    the scatter plot is open (PyXRF specific issue)
                    if data_dict.shape == xx.shape and data_dict.shape == yy.shape:
                        im = grid[i].scatter(
                            xx,
                            yy,
                            c=data_dict,
                            marker='s',
                            s=500,
                            alpha=1.0,  # Originally: alpha=0.8
                            cmap=grey_use,
                            vmin=low_limit,
                            vmax=high_limit,
                            linewidths=1,
                            linewidth=0)
                        grid[i].set_ylim(yd_axis_max, yd_axis_min)

                grid[i].set_xlim(xd_axis_min, xd_axis_max)

                grid_title = k
                grid[i].text(0,
                             1.01,
                             grid_title,
                             ha='left',
                             va='bottom',
                             transform=grid[i].axes.transAxes)

                if im is not None:
                    grid.cbar_axes[i].colorbar(im)
                    im.colorbar.formatter = im.colorbar.cbar_axis.get_major_formatter(
                    )
                    # im.colorbar.ax.get_xaxis().set_ticks([])
                    # im.colorbar.ax.get_xaxis().set_ticks([], minor=True)
                    grid.cbar_axes[i].ticklabel_format(style='sci',
                                                       scilimits=(-3, 4),
                                                       axis='both')

                #  Do not remove this code, may be useful in the future (Dmitri G.) !!!
                #  Print label for colorbar
                # cax = grid.cbar_axes[i]
                # axis = cax.axis[cax.orientation]
                # axis.label.set_text("$[a.u.]$")

            else:

                maxz = np.max(data_dict)
                # Set some reasonable minimum range for the colorbar
                #   Zeros or negative numbers will be shown in white
                if maxz <= 1e-30:
                    maxz = 1

                if not scatter_show_local:
                    if grid_interpolate_local:
                        data_dict, _, _ = grid_interpolate(
                            data_dict, self.data_dict['positions']['x_pos'],
                            self.data_dict['positions']['y_pos'])
                    im = grid[i].imshow(data_dict,
                                        norm=LogNorm(vmin=low_lim * maxz,
                                                     vmax=maxz,
                                                     clip=True),
                                        cmap=grey_use,
                                        interpolation=plot_interp,
                                        extent=(xd_min, xd_max, yd_max,
                                                yd_min),
                                        origin='upper',
                                        clim=(low_lim * maxz, maxz))
                    grid[i].set_ylim(yd_axis_max, yd_axis_min)
                else:
                    xx = self.data_dict['positions']['x_pos']
                    yy = self.data_dict['positions']['y_pos']

                    # Same guard as in the linear branch: skip the scatter plot if the
                    #    data no longer matches the positions (e.g. another file was loaded)
                    if data_dict.shape == xx.shape and data_dict.shape == yy.shape:
                        im = grid[i].scatter(
                            xx,
                            yy,
                            norm=LogNorm(vmin=low_lim * maxz, vmax=maxz,
                                         clip=True),
                            c=data_dict,
                            marker='s',
                            s=500,
                            alpha=1.0,  # Originally: alpha=0.8
                            cmap=grey_use,
                            linewidths=1,
                            linewidth=0)
                        # NOTE(review): every other branch uses
                        # set_ylim(yd_axis_max, yd_axis_min); confirm this
                        # opposite orientation is intended.
                        grid[i].set_ylim(yd_axis_min, yd_axis_max)

                grid[i].set_xlim(xd_axis_min, xd_axis_max)

                grid_title = k
                grid[i].text(0,
                             1.01,
                             grid_title,
                             ha='left',
                             va='bottom',
                             transform=grid[i].axes.transAxes)

                if im is not None:
                    grid.cbar_axes[i].colorbar(im)
                    im.colorbar.formatter = im.colorbar.cbar_axis.get_major_formatter(
                    )
                    im.colorbar.ax.get_xaxis().set_ticks([])
                    im.colorbar.ax.get_xaxis().set_ticks([], minor=True)
                    im.colorbar.cbar_axis.set_minor_formatter(
                        mticker.LogFormatter())

            grid[i].get_xaxis().set_major_locator(
                mticker.MaxNLocator(nbins="auto"))
            grid[i].get_yaxis().set_major_locator(
                mticker.MaxNLocator(nbins="auto"))

            grid[i].get_xaxis().get_major_formatter().set_useOffset(False)
            grid[i].get_yaxis().get_major_formatter().set_useOffset(False)

        self.fig.suptitle(self.img_title, fontsize=20)
        self.fig.canvas.draw_idle()
Beispiel #24
0
import matplotlib
# Render mathtext with the Computer Modern font set and a serif roman face.
matplotlib.rcParams.update({'mathtext.fontset': 'cm', 'mathtext.rm': 'serif'})

# One figure per snapshot, snapshots 0 through 55 inclusive.
for n_snap in range(56):
    # Fresh figure holding an n_rows x n_cols image grid whose panels share
    # axes and a single colorbar placed on the right-hand side.
    fig = plt.figure(figsize=(n_cols * fig_width, n_rows * 10))
    grid = ImageGrid(fig, 111,  # same layout slot as plt.subplot(111)
                     nrows_ncols=(n_rows, n_cols),
                     share_all=True,
                     axes_pad=0.2,
                     cbar_mode="single",
                     cbar_location="right",
                     cbar_pad=0.1,
                     cbar_size="5%")

    # Styling and axis extents used for this snapshot's panels.
    colormap, alpha = 'turbo', 0.6
    x_min, x_max = -3, 5.5
    y_min, y_max = 1.8, 8.

    ax = grid[0]
    dens_points = data_all[n_snap]['dens_points']
Beispiel #25
0
import pickle

# Persist only the parameters whose names contain no underscore, using the
# newest pickle protocol available.
saved_params = {name: value for name, value in params.items()
                if '_' not in name}
with open('q3_weights.pickle', 'wb') as fh:
    pickle.dump(saved_params, fh, pickle.HIGHEST_PROTOCOL)

# Q3.1.3
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

# Weights for the trained model: reload the pickle written above and show the
# first 64 columns of params['Wlayer1'], each reshaped to a 32x32 image, on an
# 8x8 grid. The file is produced by this script itself, so unpickling it is
# trusted; the context manager closes the handle after loading.
with open('q3_weights.pickle', 'rb') as handle:
    params = pickle.load(handle)

fig = plt.figure(1)
weights = params['Wlayer1']
grid = ImageGrid(fig, 111, nrows_ncols=(8, 8), axes_pad=0)

for i in range(8):
    for j in range(8):
        grid[8 * i + j].imshow(weights[:, 8 * i + j].reshape(32, 32))
plt.show()

# Weights for the initial parameters: initialize_weights is presumably
# (re)filling params['Wlayer1'] (it is read right after the call). The same
# 8x8 grid of 32x32 column images is drawn for comparison with the trained ones.
fig = plt.figure(1)
initialize_weights(train_x.shape[1], hidden_size, params, 'layer1')
W_int = params['Wlayer1']
grid = ImageGrid(fig, 111, nrows_ncols=(8, 8), axes_pad=0)

for i in range(8):
    for j in range(8):
        grid[8 * i + j].imshow(W_int[:, 8 * i + j].reshape(32, 32))
# Show this figure now, as the trained-weights grid does; otherwise it is never
# explicitly displayed and remains the current figure for later pyplot calls.
plt.show()
# Reconstruct `samp` through the model without tracking gradients, then show
# the input image followed by its reconstruction, each reshaped to 28x28.
with torch.no_grad():
    reco = hm.reconstruct(samp)
    samp_im = torch.squeeze(samp).reshape(28, 28)
    reco_im = torch.squeeze(reco).reshape(28, 28)

for shown in (samp_im, reco_im):
    plt.imshow(shown)
    plt.show()

# <codecell>
# Draw 25 samples from the model (no gradient tracking), reshape them to
# 28x28 images and lay them out on a 5x5 ImageGrid.
with torch.no_grad():
    samp = hm.sample(25)
    samp_im = torch.squeeze(samp).reshape(25, 28, 28)

fig = plt.figure(figsize=(10, 10))
grid = ImageGrid(
    fig,
    111,  # similar to subplot(111)
    nrows_ncols=(5, 5),  # creates a 5x5 grid of axes
    axes_pad=0.1,  # pad between axes in inch.
)

for ax, im in zip(grid, samp_im):
    ax.imshow(im)

fig.suptitle('Sample faces drawn from HM')
plt.show()

# TODO: debug same image problem
Beispiel #27
0
def draw_contour_map(dataset,
                     lats,
                     lons,
                     fname,
                     fmt='png',
                     gridshape=(1, 1),
                     clabel='',
                     ptitle='',
                     subtitles=None,
                     cmap=None,
                     clevs=None,
                     nlevs=10,
                     parallels=None,
                     meridians=None,
                     extend='neither',
                     aspect=8.5 / 2.5):
    '''
    Purpose::
        Create a multiple panel contour map plot and save it to disk.

    Input::
        dataset - 2d array for a single panel, or 3d array with one panel per
                  leading index. Each 2d slice must have the same shape as the
                  lon/lat grid: for 1d lats/lons that is the shape of
                  np.meshgrid(lons, lats), i.e. (nLat, nLon), so a 3d dataset
                  is (nT, nLat, nLon).
        lats - array of latitudes (1d, or already 2d matching the data slices)
        lons - array of longitudes (1d, or already 2d matching the data slices)
        fname  - a string specifying the filename of the plot, without extension
        fmt  - an optional string specifying the filetype, default is .png
        gridshape - optional tuple denoting the desired grid shape (nrows, ncols) for arranging
                    the subplots. It is passed through _best_grid_shape together
                    with the number of panels before use.
        clabel - an optional string specifying the colorbar title
        ptitle - an optional string specifying plot title
        subtitles - an optional list of strings specifying the title for each subplot
        cmap - an string or optional matplotlib.colors.LinearSegmentedColormap instance
               denoting the colormap
        clevs - an optional list of ints or floats specifying contour levels
        nlevs - an optional integer specifying the target number of contour levels if
                clevs is None
        parallels - an optional list of ints or floats for the parallels to be drawn
        meridians - an optional list of ints or floats for the meridians to be drawn
        extend - an optional string to toggle whether to place arrows at the colorbar
             boundaries. Default is 'neither', but can also be 'min', 'max', or
             'both'. Will be automatically set to 'both' if clevs is None.
        aspect - currently unused; accepted only for backward compatibility.

    Output::
        Writes the figure to '<fname>.<fmt>'. Returns None.
    '''
    # Handle the single plot case. Meridians and Parallels are not labeled for
    # multiple plots to save space.
    if dataset.ndim == 2 or (dataset.ndim == 3 and dataset.shape[0] == 1):
        if dataset.ndim == 2:
            dataset = dataset.reshape(1, *dataset.shape)
        mlabels = [0, 0, 0, 1]
        plabels = [1, 0, 0, 1]
    else:
        mlabels = [0, 0, 0, 0]
        plabels = [0, 0, 0, 0]

    # Make sure gridshape is compatible with input data
    nplots = dataset.shape[0]
    gridshape = _best_grid_shape(nplots, gridshape)

    # Set up the figure
    fig = plt.figure()
    fig.set_size_inches((8.5, 11.))
    fig.dpi = 300

    # Make the subplot grid. `add_all` is omitted: True is its default, and
    # the parameter is deprecated/removed in newer matplotlib releases.
    grid = ImageGrid(fig,
                     111,
                     nrows_ncols=gridshape,
                     axes_pad=0.3,
                     share_all=True,
                     ngrids=nplots,
                     label_mode='L',
                     cbar_mode='single',
                     cbar_location='bottom',
                     cbar_size=.15,
                     cbar_pad='0%')

    # Determine the map boundaries and construct a Basemap object
    lonmin = lons.min()
    lonmax = lons.max()
    latmin = lats.min()
    latmax = lats.max()
    m = Basemap(projection='cyl',
                llcrnrlat=latmin,
                urcrnrlat=latmax,
                llcrnrlon=lonmin,
                urcrnrlon=lonmax,
                resolution='l')

    # Expand 1d coordinate vectors to 2d grids of shape (nLat, nLon)
    if lats.ndim == 1 and lons.ndim == 1:
        lons, lats = np.meshgrid(lons, lats)

    # Calculate contour levels if not given
    if clevs is None:
        # Cut off the tails of the distribution
        # for more representative contour levels
        clevs = _nice_intervals(dataset, nlevs)
        extend = 'both'

    cmap = plt.get_cmap(cmap)

    # Create default meridians and parallels. The interval between
    # them is 1, 5, or the domain extent / 5 rounded to the nearest 10.
    length = max((latmax - latmin), (lonmax - lonmin)) / 5
    if length <= 1:
        dlatlon = 1
    elif length <= 5:
        dlatlon = 5
    else:
        dlatlon = np.round(length, decimals=-1)
    if meridians is None:
        meridians = np.r_[np.arange(0, -180, -dlatlon)[::-1],
                          np.arange(0, 180, dlatlon)]
    if parallels is None:
        parallels = np.r_[np.arange(0, -90, -dlatlon)[::-1],
                          np.arange(0, 90, dlatlon)]

    # Convert lats and lons to projection coordinates
    x, y = m(lons, lats)
    for i, ax in enumerate(grid):
        # Load the data to be plotted
        data = dataset[i]
        m.ax = ax

        # Draw the borders for coastlines and countries
        m.drawcoastlines(linewidth=1)
        m.drawcountries(linewidth=.75)

        # Draw parallels / meridians
        m.drawmeridians(meridians, labels=mlabels, linewidth=.75, fontsize=10)
        m.drawparallels(parallels, labels=plabels, linewidth=.75, fontsize=10)

        # Draw filled contours
        cs = m.contourf(x, y, data, cmap=cmap, levels=clevs, extend=extend)

        # Add title
        if subtitles is not None:
            ax.set_title(subtitles[i], fontsize='small')

    # Add colorbar. With cbar_mode='single' all panels share cbar_axes[0].
    cbar = fig.colorbar(cs,
                        cax=grid.cbar_axes[0],
                        drawedges=True,
                        orientation='horizontal',
                        extendfrac='auto')
    cbar.set_label(clabel)
    cbar.set_ticks(clevs)
    cbar.ax.xaxis.set_ticks_position('none')
    cbar.ax.yaxis.set_ticks_position('none')

    # This is an ugly hack to make the title show up at the correct height.
    # Basically save the figure once to achieve tight layout and calculate
    # the adjusted heights of the axes, then draw the title slightly above
    # that height and save the figure again. The scratch file is closed
    # (and therefore deleted) as soon as the first save is done.
    with TemporaryFile() as scratch:
        fig.savefig(scratch, bbox_inches='tight', dpi=fig.dpi)
    ymax = 0
    for ax in grid:
        bbox = ax.get_position()
        ymax = max(ymax, bbox.ymax)

    # Add figure title
    fig.suptitle(ptitle, y=ymax + .06, fontsize=16)
    fig.savefig('%s.%s' % (fname, fmt), bbox_inches='tight', dpi=fig.dpi)
    fig.clf()
    # Release pyplot's reference so repeated calls do not accumulate figures.
    plt.close(fig)
Beispiel #28
0
from orimat import qvec
from settings import e0, e1, e2, hist_bincenters, hist_binedges, a0_reciprocal

do_plotsurface = False  # set to True to plot surface orientation

runs = [4, 3, 8, 9, 7]
qarray = hist_bincenters[0] * a0_reciprocal
# Image extents: x from the first/last of hist_binedges[1], y from the
# first/last of hist_binedges[0], both scaled by a0_reciprocal.
xmin, xmax = (hist_binedges[1][0] * a0_reciprocal,
              hist_binedges[1][-1] * a0_reciprocal)
ymin, ymax = (hist_binedges[0][0] * a0_reciprocal,
              hist_binedges[0][-1] * a0_reciprocal)

## plot the 2d images
fig = plt.figure(figsize=(6, 8))
grid = ImageGrid(fig, 111, nrows_ncols=(1, 5), axes_pad=0.15,
                 share_all=True, cbar_location='right',
                 cbar_mode='single', cbar_size="15%", cbar_pad=0.15)
for i_run, run in enumerate(runs):
    ax = grid[i_run]
    filename = helper.getparam("filenames", run)
    sample = helper.getparam("samples", run)
    orimat = np.load("orimats/%s.npy" % (helper.getparam("orimats", run)))
    data_orig = np.load("data_hklmat/%s_interpolated_xy.npy" % filename)
    # Colour scale is log10 of the data, saturating at a log10 value of 3.0.
    im = ax.imshow(np.log10(data_orig), extent=[xmin, xmax, ymin, ymax],
                   origin='lower', interpolation='nearest',
                   vmax=3.0)

    if do_plotsurface:
        surf_norm = orimat.dot(qvec(50., 25., 90., 0.))
        surf_norm = surf_norm / np.linalg.norm(surf_norm)
        # A single pre-formatted string prints identically under the Python 2
        # print statement and the Python 3 print function.
        print("hkl of surface = %s" % (surf_norm,))
Beispiel #29
0
 def save_fig(self, target, reconstruction):
   """Save `target` and `reconstruction` side by side to ``self.save_fig_name``.

   A fresh figure is created per call and closed after saving. Reusing
   figure 1 would stack a new ImageGrid onto the same figure on every call.
   """
   fig = plt.figure(figsize=(1, 2))
   grid = ImageGrid(fig, 111, nrows_ncols=(1, 2), axes_pad=0.1)
   grid[0].imshow(target)
   grid[1].imshow(reconstruction)
   fig.savefig(self.save_fig_name)
   plt.close(fig)
def plot_av_vs_nhi(nhi_image,
                   av_image,
                   limits=None,
                   savedir='./',
                   filename=None,
                   show=False,
                   scale=('linear', 'linear'),
                   returnimage=False,
                   hess_binsize=None,
                   title='',
                   plot_type='hexbin',
                   color_scale='linear'):
    '''Plot A_V against N(HI) pixel by pixel as a hexbin density or scatter plot.

    Only pixels where both images are non-NaN and strictly positive are used.
    Note that this resets the global matplotlib rcParams (plt.rcdefaults)
    before applying its own styling.

    Parameters
    ----------
    nhi_image, av_image : ndarray
        Images indexed identically; N(HI) is drawn on x, A_V on y.
    limits : sequence of 4 floats, optional
        (xmin, xmax, ymin, ymax) axis limits.
    savedir, filename : str
        If `filename` is given the figure is saved to ``savedir + filename``.
    show : bool
        Call ``fig.show()`` after plotting.
    scale : pair of str
        Axis scales for x and y (e.g. 'linear' or 'log').
    returnimage : bool
        If True, return the hexbin/scatter artist (None if nothing was drawn).
    hess_binsize : unused; accepted for backward compatibility.
    title : str
        Axes title.
    plot_type : {'hexbin', 'scatter'}
    color_scale : {'linear', 'log'}
        Colour normalisation of the hexbin counts.
    '''

    # Import external modules
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib
    from matplotlib import cm
    from mpl_toolkits.axes_grid1 import ImageGrid

    # Keep only pixels where both images are non-NaN (x == x is False for
    # NaN) and strictly positive.
    indices = np.where((nhi_image == nhi_image) &
                       (av_image == av_image) &
                       (nhi_image > 0) &
                       (av_image > 0))
    nhi_image_nonans = nhi_image[indices]
    av_image_nonans = av_image[indices]

    # Set up plot aesthetics
    plt.rcdefaults()
    fig_size = (4, 4)
    font_scale = 10
    params = {
        'axes.labelsize': font_scale,
        'axes.titlesize': font_scale,
        # 'text.fontsize' is no longer a valid rcParam (rcParams.update raises
        # KeyError); 'font.size' is the supported equivalent.
        'font.size': font_scale,
        'legend.fontsize': font_scale * 3 / 4,
        'xtick.labelsize': font_scale,
        'ytick.labelsize': font_scale,
        'font.weight': 500,
        'axes.labelweight': 500,
        'text.usetex': False,
        'figure.figsize': fig_size,
    }
    plt.rcParams.update(params)

    # Scatter plots have no colour mapping, so no colorbar is needed.
    # ImageGrid expects the object None here, not the string 'None'.
    cbar_mode = None if plot_type == 'scatter' else 'single'

    # Create figure
    fig = plt.figure()
    imagegrid = ImageGrid(fig, (1, 1, 1),
                          nrows_ncols=(1, 1),
                          ngrids=1,
                          axes_pad=0.25,
                          aspect=False,
                          label_mode='L',
                          share_all=True,
                          cbar_mode=cbar_mode,
                          cbar_pad=0.1,
                          cbar_size=0.2)

    ax = imagegrid[0]
    image = None

    if plot_type == 'hexbin':
        hexbin_kwargs = dict(mincnt=1,
                             xscale=scale[0],
                             yscale=scale[1],
                             cmap=cm.gist_stern)
        if color_scale == 'linear':
            image = ax.hexbin(nhi_image_nonans.ravel(),
                              av_image_nonans.ravel(),
                              **hexbin_kwargs)
        elif color_scale == 'log':
            image = ax.hexbin(nhi_image_nonans.ravel(),
                              av_image_nonans.ravel(),
                              norm=matplotlib.colors.LogNorm(),
                              gridsize=(100, 200),
                              **hexbin_kwargs)
        if image is not None:
            cb = ax.cax.colorbar(image)
            # Legacy axes_grid1 colorbars expose set_label_text; the standard
            # matplotlib Colorbar returned by newer versions uses set_label.
            if hasattr(cb, 'set_label_text'):
                cb.set_label_text('Bin Counts')
            else:
                cb.set_label('Bin Counts')
    elif plot_type == 'scatter':
        image = ax.scatter(nhi_image_nonans.ravel(),
                           av_image_nonans.ravel(),
                           alpha=0.3,
                           color='k')
        ax.set_xscale(scale[0])
        ax.set_yscale(scale[1])

    if limits is not None:
        ax.set_xlim(limits[0], limits[1])
        ax.set_ylim(limits[2], limits[3])

    # Adjust aesthetics
    ax.set_xlabel(r'$N(HI)$ (10$^{20}$ cm$^{-2}$)')
    ax.set_ylabel(r'A$_{\rm V}$ (mag)')
    ax.set_title(title)
    ax.grid(True)

    if filename is not None:
        fig.savefig(savedir + filename, bbox_inches='tight', dpi=600)
    if show:
        fig.show()
    if returnimage:
        # Return the plotted hexbin/scatter artist (None if nothing was drawn).
        return image