Example #1
def colormap2d(xargs, yargs, cmap=None, scale=None):
    """Create a ScalarMappable2D object to map from a 2D parameter space to RGB colors.

    Arguments
    ---------
    xargs : (N,) array_like of scalars, or int
        Values determining the extrema of the first parameter-space dimension.
        If a single value `M` is given, it is treated as a number of elements and the
        index range ``[0, M-1]`` is used instead.
    yargs : (N,) array_like of scalars, or int
        Values determining the extrema of the second parameter-space dimension.
        A single value is interpreted the same way as for `xargs`.
    cmap : 2D colormap or `None`
        Mapping used by the returned object; if `None`, ``_cmap2d_hsv_lin`` is used.
    scale : one or two values, each either str or `None`
        Scaling of each dimension (e.g. 'lin' or 'log').  A single value is applied to both
        dimensions; `None` entries are inferred from the corresponding `xargs` / `yargs`.

    Returns
    -------
    smap2d : `ScalarMappable2D`
        Object to handle conversion from parameter to color spaces.  Use ``to_rgba(xx, yy)``.

    Raises
    ------
    ValueError
        If `scale` is neither a single value nor a pair of values.

    """
    # Choose a default 2d mapping
    if cmap is None:
        cmap = _cmap2d_hsv_lin

    # Broadcast a single scale specification to both dimensions
    scale = list(np.atleast_1d(scale))
    if np.size(scale) == 1:
        scale = 2 * scale
    elif np.size(scale) != 2:
        raise ValueError("`scale` must be a single or pair of values.")

    # A single value is treated as a number of elements -> index range [0, M-1].
    # Builtin ``int`` is used because the ``np.int`` alias was removed in NumPy 1.24;
    # ``np.squeeze`` also allows one-element sequences such as ``[5]``.
    if np.size(xargs) == 1:
        xargs = [0, int(np.squeeze(xargs)) - 1]
    if np.size(yargs) == 1:
        yargs = [0, int(np.squeeze(yargs)) - 1]

    # Infer any unspecified scalings from the data
    if scale[0] is None:
        scale[0] = zmath._infer_scale(xargs)
    if scale[1] is None:
        scale[1] = zmath._infer_scale(yargs)

    xlog = zplot._scale_to_log_flag(scale[0])
    ylog = zplot._scale_to_log_flag(scale[1])

    # For log-scaled dimensions the extrema search uses the 'g' filter
    # (presumably "greater than zero", as required by ``LogNorm``)
    xfilter = 'g' if xlog else None
    yfilter = 'g' if ylog else None

    xmin, xmax = zmath.minmax(xargs, filter=xfilter)
    ymin, ymax = zmath.minmax(yargs, filter=yfilter)

    if xlog:
        xnorm = mpl.colors.LogNorm(vmin=xmin, vmax=xmax)
    else:
        xnorm = mpl.colors.Normalize(vmin=xmin, vmax=xmax)

    if ylog:
        ynorm = mpl.colors.LogNorm(vmin=ymin, vmax=ymax)
    else:
        ynorm = mpl.colors.Normalize(vmin=ymin, vmax=ymax)

    # NOTE(review): confirm the ``ScalarMappable2D(norms, cmap)`` argument order against its
    # definition; `cmap` is the 2D mapping applied to the pair of normalized values.
    smap2d = ScalarMappable2D([xnorm, ynorm], cmap)
    return smap2d
Example #2
 def interp(xx):
     """Evaluate the interpolant ``interp_ll`` in log-log space, returning linear values.

     Computes ``10**interp_ll(log10(xx))``.  ``interp_ll`` comes from the enclosing scope
     (presumably a ``scipy.interpolate.RegularGridInterpolator`` built on log10 values,
     given its ``grid`` attribute -- confirm).

     Raises
     ------
     ValueError
         Re-raised from ``interp_ll`` (e.g. for out-of-bounds arguments), after logging the
         argument and the extent of each interpolation-grid dimension.

     """
     try:
         res = 10**interp_ll(np.log10(xx))
     except ValueError:
         logging.error("ValueError for argument: '{}'".format(xx))
         logging.error("ValueError for argument: log: '{}'".format(np.log10(xx)))
         for gg in interp_ll.grid:
             logging.error("\tgrid dimension extrema: {}, {}".format(np.min(gg), np.max(gg)))
         # Re-raise: swallowing the error would leave `res` unbound, turning it into an
         # unrelated UnboundLocalError at the `return` below.
         raise
     return res
Example #3
def _set_extrema(extrema, vals, filter=None, lo=None, hi=None):
    _extr = None
    for vv in vals:
        use_vv = np.array(vv)
        if lo is not None:
            use_vv = use_vv[use_vv > lo]
        if hi is not None:
            use_vv = use_vv[use_vv < hi]
        _extr = zmath.minmax(use_vv, filter=filter, prev=_extr, stretch=0.05)

    if extrema is None: new_extr = _extr
    else:               new_extr = np.asarray(extrema)
    if new_extr[0] is None: new_extr[0] = _extr[0]
    if new_extr[1] is None: new_extr[1] = _extr[1]
    new_extr = new_extr.astype(np.float64)
    return new_extr
Example #4
def zoom(ax, loc, axis='x', scale=2.0):
    """Zoom-in at a certain location on the given axes.

    The new limits are centered on `loc` and span ``1/scale`` of the current range of the
    target axis (computed in log10-space for logarithmic axes).  The limits are applied to
    `ax` and also returned.

    Arguments
    ---------
    ax : ``matplotlib.axes.Axes``
        Axes to modify.
    loc : scalar
        Location, in data coordinates, to center the zoom on.
    axis : {'x', 'y'}
        Which axis to zoom.
    scale : scalar
        Zoom factor: the new span is the current span divided by `scale`.

    Returns
    -------
    lim : (2,) ndarray
        New axis limits, in data coordinates.

    Raises
    ------
    ValueError
        If `axis` is not 'x' or 'y', or the axis scale is neither linear nor logarithmic.

    """
    # Choose functions based on target axis
    if axis == 'x':
        axScale = ax.get_xscale()
        lim = ax.get_xlim()
        set_lim = ax.set_xlim
    elif axis == 'y':
        axScale = ax.get_yscale()
        lim = ax.get_ylim()
        set_lim = ax.set_ylim
    else:
        raise ValueError("Unrecognized ``axis`` = '%s'!!" % (str(axis)))

    lim = np.array(lim)

    # Determine axis scaling
    if axScale.startswith('lin'):
        log = False
    elif axScale.startswith('log'):
        log = True
    else:
        raise ValueError("``axScale`` '%s' not implemented!" % (str(axScale)))

    # Convert to log if appropriate
    if log:
        lim = np.log10(lim)
        loc = np.log10(loc)

    # Find new axis bounds: keep `1/scale` of the current span, centered on `loc`
    delta = lim.max() - lim.min()
    lim = np.array([loc - (0.5 / scale) * delta, loc + (0.5 / scale) * delta])
    # Convert back to linear if appropriate
    if log:
        lim = np.power(10.0, lim)

    # Apply the new limits to the axes
    set_lim(lim)

    return lim
Example #5
def _plotPlotZoom(ax1, ax2, zoomLoc, xx, yy, zz=None, zoomScale=20.0):
    """Plot a full range in one axis, and a zoom-in in another.

    Both axes show a line plot of ``(xx, yy)`` and a scatter plot of ``(xx, zz)``.  `ax2` is
    then zoomed along its x-axis around `zoomLoc` by a factor `zoomScale` (``zplot.zoom``).
    Its y-limits are set from ``zplot.limits(xx, yy, xlim)``, presumably the y-range of the
    data inside the new x-range.

    Arguments
    ---------
    ax1 : ``matplotlib.axes.Axes``
        Axes showing the full data range.
    ax2 : ``matplotlib.axes.Axes``
        Axes showing the zoomed-in view.
    zoomLoc : scalar
        x-location to center the zoom on.
    xx, yy : (N,) array_like
        Data for the line plot.
    zz : (N,) array_like or `None`
        Data for the scatter points; defaults to `yy`.
    zoomScale : scalar
        Zoom factor passed to ``zplot.zoom``.

    """
    ALPHA = 0.5
    SIZE = 10
    COL = 'k'

    if zz is None:
        zz = yy

    # Same data on both axes: full range (ax1) and the view to be zoomed (ax2)
    for ax in (ax1, ax2):
        ax.plot(xx, yy, '-', color=COL, alpha=ALPHA)
        ax.scatter(xx, zz, s=SIZE, color=COL, alpha=ALPHA)

    # Zoom-in along x, then fit the y-range to the data within the new x-range
    xlim = zplot.zoom(ax2, zoomLoc, axis='x', scale=zoomScale)
    ylim = zplot.limits(xx, yy, xlim)
    # Apply both limits explicitly (``ylim`` was otherwise computed but never used)
    ax2.set_xlim(xlim)
    ax2.set_ylim(ylim)
Example #6
def plotCorrelationGrid(data, figure=None, style='scatter', confidence=True, contours=True,
                        pars_scales=None, hist_scales=None, hist_bins=None, names=None, fs=12):
    """Plot a grid of correlation graphs, showing histograms of arrays and correlations between pairs.

    Histograms are drawn on the diagonal, pairwise correlations in the lower-left triangle,
    and the upper-right triangle is hidden.

    Arguments
    ---------
        data <scalar>[N][M]        : ``N`` different parameters, with ``M`` values each

        figure      <obj>          : ``matplotlib.figure.Figure`` object on which to plot
        style       <str>          : what style of correlation plots to make
                                     - 'scatter'
        confidence  <bool>         : Add confidence intervals to histogram plots
        contours    <bool>         : Add contour lines to correlation plots
        pars_scales <scalar>([N])  : What scale to use for all (or each) parameter {'lin', 'log'}
        hist_scales <scalar>([N])  : What y-axis scale to use for all (or each) histogram
        hist_bins   <scalar>([N])  : Number of bins for all (or each) histogram
        names       <str>([N])     : Parameter names, passed to the axes configuration
        fs          <int>          : Font size, passed to the axes configuration

    Returns
    -------
        figure <obj>      : ``matplotlib.figure.Figure`` object
        axes   <obj>[N,N] : array of ``matplotlib.axes`` objects

    Raises
    ------
        RuntimeError : if ``figure`` already has axes but not ``N*N`` of them, or if the
                       correlation ``style`` is not implemented.

    """
    npars = len(data)

    # Set default scales for each parameter.  Sequences are copied so the normalization below
    # does not modify the caller's list (and also works for tuples).
    if pars_scales is None:
        pars_scales = ['linear'] * npars
    elif isinstance(pars_scales, str):
        pars_scales = [pars_scales] * npars
    else:
        pars_scales = list(pars_scales)

    # Set default scales for each histogram (counts)
    if hist_scales is None:
        hist_scales = ['linear'] * npars
    elif isinstance(hist_scales, str):
        hist_scales = [hist_scales] * npars
    else:
        hist_scales = list(hist_scales)

    # Convert scaling strings to appropriate formats
    for ii in range(npars):
        if pars_scales[ii].startswith('lin'):
            pars_scales[ii] = 'linear'
        elif pars_scales[ii].startswith('log'):
            pars_scales[ii] = 'log'
        if hist_scales[ii].startswith('lin'):
            hist_scales[ii] = 'linear'
        elif hist_scales[ii].startswith('log'):
            hist_scales[ii] = 'log'

    # Set default bins
    if hist_bins is None:
        hist_bins = [40] * npars
    elif isinstance(hist_bins, numbers.Integral):
        hist_bins = [hist_bins] * npars

    # Setup Figure and Axes
    # ---------------------
    if figure is None:
        figure = plt.figure()

    if len(figure.axes) > 0:
        # Axes are already on figure: make sure the number of axes is correct
        if len(figure.axes) != npars * npars:
            raise RuntimeError("``figure`` axes must be {0:d}x{0:d}!".format(npars))
    else:
        # Create axes, dividing the figure evenly with padding
        dx = (_RIGHT - _LEFT) / npars
        dy = (_TOP - _BOT) / npars
        for ii in range(npars):          # rows
            for jj in range(npars):      # columns
                ax = figure.add_axes([_LEFT + jj * dx, _TOP - (ii + 1) * dy, dx, dy])
                # Make upper-right half of figure invisible
                if jj > ii:
                    ax.set_visible(False)

    # Reshape to grid for convenience (axes were added row by row)
    axes = np.array(figure.axes).reshape(npars, npars)

    # Plot Correlations and Histograms
    # --------------------------------
    lims = []
    for ii in range(npars):
        # Only the diagonal and lower-left triangle (jj <= ii) are plotted
        for jj in range(ii + 1):
            if ii == jj:
                # Histograms
                zplot.plotHistBars(axes[ii, jj], data[ii], bins=hist_bins[ii],
                                   scalex=pars_scales[ii], conf=confidence)
            elif style == 'scatter':
                # Correlations
                zplot.plotScatter(axes[ii, jj], data[jj], data[ii],
                                  scalex=pars_scales[jj], scaley=pars_scales[ii], cont=contours)
            else:
                raise RuntimeError("``style`` '%s' is not implemented!" % (style))

    # Configure Axes
    _config_axes(axes, lims, pars_scales, hist_scales, names, fs)

    return figure, axes
Example #7
# Draw a 2D polygonal mesh on `ax`, as outline lines (`lines_flag`) and/or as filled patches
# colored by `vals`.  Returns the scalar-mappable and the color extrema.
# NOTE(review): this copy is incomplete.  The signature is cut off after `ax,`; the body reads
# `mesh`, `vals`, `lines_flag`, `periodic`, `region`, `fix_poly`, `smap` and `kwargs`.
# Several statements are also missing (flagged below).
def draw_mesh(ax,

    if (not lines_flag) and (vals is None):
        raise ValueError("Nothing is being plotted!")

    ndt = mesh['ndt_tot']
    xyz = mesh['xyz_edges']
    edge_list = mesh['edge_list']
    nedge_offset = mesh['nedge_offset']
    nedges = mesh['nedges']
    NDIM = 2
    NSIDE_MAX = 20

    # Vertex coordinates as an (ndt, NDIM) array; `edge_list` entries index its rows
    xyz = xyz.reshape(ndt, NDIM)

    # Scratch buffer holding the vertices of the current cell
    poly = np.zeros((NSIDE_MAX, NDIM))
    if lines_flag:
        # NaN-initialized so unfilled rows separate consecutive polylines in `ax.plot`
        lines = np.full((2 * len(edge_list), NDIM), np.nan)
    tot_num = len(nedge_offset)

    # With periodic boundaries, reserve room for duplicated (wrapped) cells
    mult = 1 if (periodic is None) else 2

    if periodic is not None:
        periodic = [np.array(pp) if pp is not None else pp for pp in periodic]
        # np.atleast_2d(periodic)
    if region is not None:
        # region = [np.array(pp) if pp is not None else pp for pp in region]
        region = np.atleast_2d(region)

    if vals is not None:
        patches = np.empty(mult * tot_num, dtype=object)
        colors = np.zeros(mult * tot_num)

    # `cnt`: next free row in `lines`; `valid`: which patch/color slots were filled
    cnt = 0
    valid = np.zeros(mult * tot_num, dtype=bool)

    # Store one cell (the first `ne` rows of `poly`): append its outline to `lines` and/or
    # create its patch and color.  Regular cells use slot `ee` and periodic duplicates
    # (end=True) use slots counted from the back.  Returns the updated row counter.
    def add_cell(ee, ne, poly, cnt, end=False):
        if end:
            ff = mult * tot_num - 1 - ee
            # NOTE(review): an `else:` line appears to be missing here.  As written, `ff = ee`
            # always overrides the line above, and `ff` is unbound when `end` is False.
            ff = ee

        if lines_flag:
            lines[cnt:cnt + ne, :] = poly[:ne, :]

        if vals is not None:
            inc = 0
            # Optionally drop a repeated closing vertex before building the patch
            if fix_poly and np.allclose(poly[0, :], poly[ne - 1, :]):
                ne = ne - 1
                inc = 1

            pat = mpl.patches.Polygon(poly[:ne])
            patches[ff] = pat
            colors[ff] = vals[ee]
            ne = ne + inc

        valid[ff] = True
        # +1 leaves one NaN row after each outline, breaking the line between cells
        cnt = cnt + ne + 1
        return cnt

    # Number of periodic duplicates created
    pers = 0
    for ee in tqdm.tqdm(range(tot_num), total=tot_num, leave=False):
        # Gather the `ne` vertices of cell `ee` from the flattened edge list
        oo = nedge_offset[ee]
        ne = nedges[ee]
        ll = edge_list[oo]
        lo = xyz[ll]
        poly[0] = lo
        if ne >= NSIDE_MAX:
            err = "Number of edges for element {} = {}, exceeds max {}!".format(
                ee, ne, NSIDE_MAX)
            raise ValueError(err)

        for ff in range(1, ne):
            hh = edge_list[oo + ff]
            hi = xyz[hh]
            poly[ff] = hi

        # NOTE(review): the body of this `if` is missing (presumably `continue`, skipping
        # cells that are not within `region`).
        if (region is not None) and (not any_within(poly[:ne], region)):

        cnt = add_cell(ee, ne, poly, cnt, end=False)

        # NOTE(review): the body of this `if` is missing (presumably `continue`).
        if periodic is None:

        # Cells extending past a periodic bound are duplicated, shifted by one period
        for dd in range(NDIM):
            # NOTE(review): the body of this `if` is missing (presumably `continue`).
            if periodic[dd] is None:

            # NOTE(review): the duplicate slot in `add_cell` depends only on `ee`, so a cell
            # wrapped in both dimensions overwrites its first duplicate.
            if np.any((poly[:ne, dd] < periodic[dd][0])):
                dup = np.copy(poly[:ne, :])
                dup[:, dd] += (periodic[dd][1] - periodic[dd][0])
                cnt = add_cell(ee, ne, dup, cnt, end=True)
                pers += 1
            elif np.any(poly[:ne, dd] > periodic[dd][1]):
                dup = np.copy(poly[:ne, :])
                dup[:, dd] -= (periodic[dd][1] - periodic[dd][0])
                cnt = add_cell(ee, ne, dup, cnt, end=True)
                pers += 1

        # if cnt > 1000:
        #     break

    # NOTE(review): `colors` only exists when `vals is not None`, so this line raises
    # NameError when only lines are drawn.  It presumably belongs inside the `if` below.
    extr = zmath.minmax(colors[valid])
    if vals is not None:
        if smap is None:
            smap = zplot.smap(extr, cmap='viridis')

        # NOTE(review): this call is truncated in this copy.  Presumably the collection is
        # colored with `colors[valid]` / `smap` and added to `ax` -- confirm.
        p = mpl.collections.PatchCollection(patches[valid],

    if lines_flag:
        # Keep the filled rows; transpose to (NDIM, cnt) so `ax.plot(x, y)` gets coordinates
        lines = lines[:cnt, :].T
        ax.plot(*lines, **kwargs)

    return smap, extr
Example #8
# For the snapshots of the simulation at `path`, compute radial projections (via
# `lysis.plot.rad_proj_from_2d`) to accumulate plot extrema, then plot and save one figure
# per snapshot (base name "proj1d_{scatter_param}-...png").  If `movie` is set, the figures
# are combined with `make_movie`.
# NOTE(review): this copy is incomplete.  The signature (after `path,`) and several call
# argument lists (fname format, rad_proj_from_2d, plot_proj_1d_snap, save_fig, make_movie)
# are truncated.
def plot_proj_1d(path,
    num_snaps = len(lysis.readio.snap_files(path))
    fname = "proj1d_{}-{}.png".format(scatter_param.lower(),

    init = None
    mean = None
    if snaps is None:
        snaps = range(num_snaps)

    # First pass: accumulate extrema (and the first / mean projected histogram)
    yextr = None
    for snap in tqdm.tqdm(snaps):
        rads, scat, hist, edges = lysis.plot.rad_proj_from_2d(
        # NOTE(review): `extr` is not initialized in the visible code (presumably a parameter).
        if extr is None:
            extr = zmath.minmax(edges)
        if init is None:
            init = np.copy(hist)

        idx = (hist > 0.0)
        if mean is None:
            mean = np.zeros_like(hist)
            mean[idx] = hist[idx]
            # NOTE(review): an `else:` line appears to be missing above the next line.  As
            # written, both assignments run only for the first snapshot (giving 2*hist) and
            # nothing is accumulated afterwards.
            mean[idx] = mean[idx] + hist[idx]

        yextr = zmath.minmax(scat, prev=yextr, log_stretch=0.1, filter='>')

    # NOTE(review): divides by the total number of snapshot files, not `len(snaps)`; these
    # differ when a subset of snapshots is requested.  `mean` and `init` are not used later in
    # the visible code.
    mean = mean / num_snaps
    if ylim is None:
        ylim = yextr

    # Second pass: plot and save one figure per snapshot, with common axis limits
    # NOTE(review): `output_fnames` is never appended to in the visible code, so
    # `output_fnames[0]` below raises IndexError unless a missing line appends `_fname`.
    output_fnames = []
    for snap_num in tqdm.tqdm(snaps):
        fig = lysis.plot.plot_proj_1d_snap(path,

        # NOTE(review): `xlim` is not defined in the visible code (presumably a parameter).
        fig.axes[0].set(xlim=xlim, ylim=ylim)
        _fname = lysis.save_fig(fig,
        if _fname is None:
            print(_fname, fname)
            raise RuntimeError()


    # Default movie framerate: one-twentieth of the number of frames, clipped to [2, 15]
    if framerate is None:
        framerate = len(snaps) / 20
        framerate = int(np.clip(framerate, 2, 15))

    print("Saved to (e.g.) '{}'".format(output_fnames[0]))
    if movie:
        movie_fname = make_movie(path,
                                 fname.replace('.png', '.mp4'),
        print("Saved movie to '{}'".format(movie_fname))

Example #9
def subhaloRadialProfiles(run, snapNum, subhalo, radBins=None, nbins=NUM_RAD_BINS,
                          mostBound=None, verbose=True):
    """Construct binned, radial profiles of density for each particle species.

    Profiles for the velocity dispersion and gravitational potential are also constructed for
    all particle types together.  Radii are measured from the most-bound particle.

    Arguments
    ---------
       run       <int>    : illustris simulation run number {1, 3}
       snapNum   <int>    : illustris simulation snapshot number {1, 135}
       subhalo   <int>    : subhalo index number for target snapshot
       radBins   <flt>[N] : optional, right-edges of radial bins in simulation units
       nbins     <int>    : optional, numbers of bins to create if ``radBins`` is `None`
       mostBound <int>    : optional, ID number of the most-bound particle for this subhalo
       verbose   <bool>   : optional, print verbose output

    Returns
    -------
       radBins   <flt>[N]    : coordinates of right-edges of ``N`` radial bins
       posRef    <flt>[3]    : coordinates in simulation box of most-bound particle (used as C.O.M.)
       mostBound <int>       : ID number of the most-bound particle
       partTypes <int>[M]    : particle type numbers for ``M`` types, (``illpy_lib.constants.PARTICLE``)
       partNames <str>[M]    : particle type strings for each type
       numsBins  <int>[M, N] : binned number of particles for ``M`` particle types, ``N`` bins each
       massBins  <flt>[M, N] : binned radial mass profile
       densBins  <flt>[M, N] : binned mass "density" profile (mass / Δ(r^3), see note below)
       potsBins  <flt>[N, 2] : binned gravitational potential energy for all particles,
                               average in [:, 0] and standard deviation in [:, 1]
       dispBins  <flt>[N, 2] : binned velocity dispersion for all particles,
                               average in [:, 0] and standard deviation in [:, 1]

       ``None`` is returned instead (with a ``RuntimeWarning``) when:
       - the most-bound particle cannot be identified or found;
       - particle counts are inconsistent after loading or binning;
       - no particle radii are available to construct bins.

    """

    if verbose:
        print(" - - Profiler.subhaloRadialProfiles()")

    if verbose:
        print(" - - - Loading subhalo partile data")
    # Redirect output during this call
    with zio.StreamCapture():
        partData, partTypes = Subhalo.importSubhaloParticles(run, snapNum, subhalo, verbose=False)

    partNums = [pd['count'] for pd in partData]
    partNames = [PARTICLE.NAMES(pt) for pt in partTypes]
    numPartTypes = len(partNums)

    # Find the most-bound Particle
    #  ----------------------------

    posRef = None

    # If no particle ID is given, look it up in the group catalog
    if mostBound is None:
        mostBound = Subhalo.importGroupCatalogData(
            run, snapNum, subhalos=subhalo, fields=[SUBHALO.MOST_BOUND])

    if mostBound is None:
        warnStr  = "Could not find mostBound particle ID Number!"
        warnStr += "Run %d, Snap %d, Subhalo %d" % (run, snapNum, subhalo)
        warnings.warn(warnStr, RuntimeWarning)
        return None

    thisStr = "Run %d, Snap %d, Subhalo %d, Bound ID %d" % (run, snapNum, subhalo, mostBound)
    if verbose:
        print((" - - - - {:s} : Loaded {:s} particles".format(thisStr, str(partNums))))

    # Find the most-bound particle, store its position
    for pdat, pname in zip(partData, partNames):
        # Skip, if no particles of this type
        if pdat['count'] == 0:
            continue
        inds = np.where(pdat[SNAPSHOT.IDS] == mostBound)[0]
        if len(inds) == 1:
            if verbose:
                print((" - - - Found Most Bound Particle in '{:s}'".format(pname)))
            posRef = pdat[SNAPSHOT.POS][inds[0]]

    # Set warning and return ``None`` if most-bound particle is not found
    if posRef is None:
        warnStr = "Could not find most bound particle in snapshot! %s" % (thisStr)
        warnings.warn(warnStr, RuntimeWarning)
        return None

    # Per-type arrays (one entry per particle type; entries have different lengths)
    mass = np.zeros(numPartTypes, dtype=object)
    rads = np.zeros(numPartTypes, dtype=object)
    pots = np.zeros(numPartTypes, dtype=object)
    disp = np.zeros(numPartTypes, dtype=object)
    radExtrema = None

    # Iterate over all particle types and their data
    #  ==============================================

    if verbose:
        print(" - - - Extracting and processing particle properties")
    for ii, (data, ptype) in enumerate(zip(partData, partTypes)):

        # Make sure the expected number of particles are found
        if data['count'] != partNums[ii]:
            warnStr  = "%s" % (thisStr)
            warnStr += "Type '%s' count mismatch after loading!!  " % (partNames[ii])
            warnStr += "Expecting %d, Retrieved %d" % (partNums[ii], data['count'])
            warnings.warn(warnStr, RuntimeWarning)
            return None

        # Skip if this particle type has no elements
        #    use empty lists so that call to ``np.concatenate`` below works (ignored)
        if data['count'] == 0:
            mass[ii] = []
            rads[ii] = []
            pots[ii] = []
            disp[ii] = []
            continue

        # Extract positions from snapshot, make sure reflections are nearest most-bound particle
        posn = reflectPos(data[SNAPSHOT.POS], center=posRef)

        # DarkMatter Particles all have the same mass, store that single value
        if ptype == PARTICLE.DM:
            mass[ii] = [GET_ILLUSTRIS_DM_MASS(run)]
        else:
            mass[ii] = data[SNAPSHOT.MASS]

        # Convert positions to radii from ``posRef`` (most-bound particle), and find radial extrema
        rads[ii] = zmath.dist(posn, posRef)
        pots[ii] = data[SNAPSHOT.POT]
        disp[ii] = data[SNAPSHOT.SUBF_VDISP]
        radExtrema = zmath.minmax(rads[ii], prev=radExtrema, nonzero=True)

    # Create Radial Bins
    #  ------------------

    # Create radial bin spacings, these are the upper-bound radii
    if radBins is None:
        if radExtrema is None:
            warnStr = "%s : No non-zero particle radii to construct radial bins!" % (thisStr)
            warnings.warn(warnStr, RuntimeWarning)
            return None
        # Pad the extrema slightly so the innermost and outermost particles are binned
        radExtrema[0] = radExtrema[0]*0.99
        radExtrema[1] = radExtrema[1]*1.01
        radBins = zmath.spacing(radExtrema, scale='log', num=nbins)

    # Radial bin (shell) "volumes": differences of r^3 between successive right-edges, with the
    # first bin extending down to r = 0.
    # NOTE(review): these omit the 4*pi/3 factor, so ``densBins`` is mass / Δ(r^3) rather than
    # a physical density -- confirm the normalization callers expect.
    numBins = len(radBins)
    binVols = np.diff(np.power(radBins, 3.0), prepend=0.0)

    # Bin Properties for all Particle Types
    # -------------------------------------
    densBins = np.zeros([numPartTypes, numBins], dtype=DTYPE.SCALAR)    # Density
    massBins = np.zeros([numPartTypes, numBins], dtype=DTYPE.SCALAR)    # Mass
    numsBins = np.zeros([numPartTypes, numBins], dtype=DTYPE.INDEX)     # Count of particles

    # second dimension to store averages [0] and standard-deviations [1]
    potsBins = np.zeros([numBins, 2], dtype=DTYPE.SCALAR)               # Grav Potential Energy
    dispBins = np.zeros([numBins, 2], dtype=DTYPE.SCALAR)               # Velocity dispersion

    # Iterate over particle types
    if verbose:
        print(" - - - Binning properties by radii")
    for ii, (data, ptype) in enumerate(zip(partData, partTypes)):

        # Skip if this particle type has no elements
        if data['count'] == 0:
            continue

        # Get the total mass in each bin
        numsBins[ii, :], massBins[ii, :] = zmath.histogram(rads[ii], radBins, weights=mass[ii],
                                                           edges='right', func='sum', stdev=False)

        # Divide by volume to get density
        densBins[ii, :] = massBins[ii, :]/binVols

    if verbose:
        print((" - - - - Binned {:s} particles".format(str(np.sum(numsBins, axis=1)))))

    # Consistency check on numbers of particles
    # -----------------------------------------
    #      The total number of particles ``numTot`` shouldn't necessarily be in bins.
    #      The expected number of particles ``numExp`` are those that are within the bounds of bins

    for ii in range(numPartTypes):

        # ``np.asarray`` because empty particle types store plain lists
        typeRads = np.asarray(rads[ii])
        numExp = np.count_nonzero(typeRads <= radBins[-1])
        numAct = np.sum(numsBins[ii])
        numTot = np.size(typeRads)

        # If there is a discrepancy return ``None`` for error
        if numExp != numAct:
            warnStr  = "%s\nType '%s' count mismatch after binning!" % (thisStr, partNames[ii])
            warnStr += "\nExpecting %d, Retrieved %d" % (numExp, numAct)
            warnings.warn(warnStr, RuntimeWarning)
            return None

        # If a noticeable number of particles are not binned, warn, but still continue
        elif (numAct < numTot-10 and numAct < 0.9*numTot):
            warnStr  = "%s : Type %s" % (thisStr, partNames[ii])
            warnStr += "\nTotal = %d, Expected = %d, Binned = %d" % (numTot, numExp, numAct)
            warnStr += "\nBin Extrema = %s" % (str(zmath.minmax(radBins)))
            warnStr += "\nRads = %s" % (str(rads[ii]))
            warnings.warn(warnStr, RuntimeWarning)

    # Convert list of arrays into 1D arrays of all elements
    rads = np.concatenate(rads)
    pots = np.concatenate(pots)
    disp = np.concatenate(disp)

    # Bin Grav Potentials
    counts, aves, stds = zmath.histogram(rads, radBins, weights=pots,
                                         edges='right', func='ave', stdev=True)
    potsBins[:, 0] = aves
    potsBins[:, 1] = stds

    # Bin Velocity Dispersion
    counts, aves, stds = zmath.histogram(rads, radBins, weights=disp,
                                         edges='right', func='ave', stdev=True)
    dispBins[:, 0] = aves
    dispBins[:, 1] = stds

    return radBins, posRef, mostBound, partTypes, partNames, \
        numsBins, massBins, densBins, potsBins, dispBins
Example #10
def plot_grid(grid, grid_names, temps, valid, interp=None):
    """Plot a corner-style grid of slices through the multi-dimensional array `temps`.

    Panel (ii, jj) of the ``num x num`` grid (``num = len(grid)``), all on log-log axes:
      - ii < jj (upper triangle): handling is missing in this copy (presumably skipped).
      - ii == jj (diagonal): 1D slice of `temps` along dimension ii plotted against
        ``grid[ii]``, with every other dimension fixed at the index in `def_idx`.  If
        `interp` is given, it is evaluated along the same slice and over-plotted as a red
        dashed line.
      - ii > jj (lower triangle): 2D slice of `temps` over dimensions (jj, ii), drawn with
        ``pcolor`` using a 'viridis' colormap spanning the positive extrema of `temps`, and
        titled with the slice's positive extrema.  If `interp` is given, 10 random points
        are evaluated and scattered with the same colormap.  The points come from
        ``np.random.uniform``, unseeded, so the output varies between calls.

    `grid_names` and `valid` are only referenced in commented-out code.

    Returns the ``matplotlib`` figure.

    NOTE(review): `def_idx` has 4 entries and the interpolation check reads ``temp[2]`` and
    ``temp[3]``, so this appears to assume a 4-dimensional grid -- confirm.
    NOTE(review): this copy is incomplete.  Several lines are missing: `if`/`else` bodies,
    closing brackets of the list comprehensions, and the `np.random.uniform` arguments.
    They are flagged below.
    """
    import matplotlib.pyplot as plt
    import zcode.plot as zplot

    extr = zmath.minmax(temps, filter='>')
    smap = zplot.colormap(extr, 'viridis')

    # bads = valid & np.isclose(temps, 0.0)

    num = len(grid)
    fig, axes = plt.subplots(figsize=[14, 14], nrows=num, ncols=num)
    plt.subplots_adjust(hspace=0.4, wspace=0.4)

    # Index at which each dimension is held fixed when slicing `temps`
    def_idx = [-4, -4, 4, -4]

    for (ii, jj), ax in np.ndenumerate(axes):
        # NOTE(review): the body of this `if` is missing (presumably hides the axes and
        # `continue`s).
        if ii < jj:

        ax.set(xscale='log', yscale='log')
        xx = grid[jj]
        if ii == jj:
            # print(grid_names[ii], zmath.minmax(grid[ii], filter='>'))
            # idx = list(range(num))
            # idx.pop(ii)
            # idx = tuple(idx)
            # vals = np.mean(temps, axis=idx)

            # NOTE(review): the closing `]` of this list comprehension is missing.
            idx = [
                slice(None) if aa == ii else def_idx[aa] for aa in range(num)
            vals = temps[tuple(idx)]
            ax.plot(xx, vals, 'k-')

            if interp is not None:
                num_test = 10
                # Test points: every dimension at its default grid value, except dimension
                # ii, which is log-spaced across its grid
                # NOTE(review): the closing `]` of this list comprehension is missing.
                test = [
                    np.ones(num_test) * grid[aa][def_idx[aa]]
                    for aa in range(num)
                test[ii] = zmath.spacing(grid[ii], 'log', num_test)
                test_vals = [interp(tt) for tt in np.array(test).T]
                ax.plot(test[ii], test_vals, 'r--')

            # bad_vals = np.count_nonzero(bads, axis=idx)
            # tw = ax.twinx()
            # tw.plot(xx, bad_vals, 'r--')

            # print(ii, jj)
            # print("\t", ii, grid_names[ii], zmath.minmax(grid[ii], filter='>'))
            # print("\t", jj, grid_names[jj], zmath.minmax(grid[jj], filter='>'))
            # idx = [0, 1, 2, 3]
            # idx.pop(np.max([ii, jj]))
            # idx.pop(np.min([ii, jj]))
            # vals = np.mean(temps, axis=tuple(idx))

            # NOTE(review): an `else:` line (the lower-triangle branch, ii > jj) appears to be
            # missing above; as written the code below sits inside the diagonal branch.
            # 2D slice over dimensions (jj, ii); with jj < ii numpy keeps them in that order,
            # matching the 'ij'-indexed meshgrid below.
            # idx = [slice(None) if aa in [ii, jj] else num//2 for aa in range(num)]
            # NOTE(review): the closing `]` of this list comprehension is missing.
            idx = [
                slice(None) if aa in [ii, jj] else def_idx[aa]
                for aa in range(num)
            vals = temps[tuple(idx)]
            # NOTE(review): the body of this `if` is missing (presumably `continue`).
            if len(vals) == 0:

            yy = grid[ii]
            xx, yy = np.meshgrid(xx, yy, indexing='ij')
            ax.pcolor(xx, yy, vals, cmap=smap.cmap, norm=smap.norm)

            # NOTE(review): the body of this `if` is missing (presumably `continue`, which
            # would avoid computing a title from an all-non-positive slice).
            if np.count_nonzero(vals > 0.0) == 0:

            tit = "{:.1e}, {:.1e}".format(*zmath.minmax(vals, filter='>'))
            ax.set_title(tit, size=10)

            # bad_vals = np.count_nonzero(bads, axis=tuple(idx))
            # idx = (bad_vals > 0.0)
            # aa = xx[idx]
            # bb = yy[idx]
            # cc = bad_vals[idx]
            # ax.scatter(aa, bb, s=2*cc**2, color='0.5', alpha=0.5)
            # ax.scatter(aa, bb, s=cc**2, color='r')

            if interp is not None:
                for kk in range(10):
                    idx = (vals > 0.0)
                    # NOTE(review): the arguments of both `np.random.uniform(` calls are
                    # missing in this copy.
                    x0 = 10**np.random.uniform(
                    y0 = 10**np.random.uniform(
                    # y0 = np.random.choice(yy[idx])

                    # Evaluation point: default grid values, with dimensions ii/jj replaced
                    temp = [grid[ll][def_idx[ll]] for ll in range(num)]
                    temp[ii] = y0
                    temp[jj] = x0

                    # NOTE(review): hard-coded adjustment enforcing temp[2] < temp[3]; the
                    # meaning of 3.1 is not visible here -- confirm.
                    if temp[2] >= temp[3]:
                        temp[2] = 3.1
                    iv = interp(temp)
                    # On a non-finite or zero result, nudge dimensions whose default index is
                    # 0 (x1.11) or -1 (x0.99), retry once, and print the values
                    if not np.isfinite(iv) or np.isclose(iv, 0.0):

                        # NOTE(review): `kk` shadows the outer loop variable; harmless here
                        # (the outer `range` iterator is unaffected) but confusing.
                        for kk in range(num):
                            if def_idx[kk] == 0:
                                temp[kk] = temp[kk] * 1.11
                            elif def_idx[kk] == -1:
                                temp[kk] = 0.99 * temp[kk]

                        iv = interp(temp)
                        print("\t", temp)
                        print("\t", iv)

                    # Grey halo behind a point colored by the interpolated value
                    cc = smap.to_rgba(iv)
                    ss = 20
                    ax.scatter(temp[jj], temp[ii], color='0.5', s=2 * ss)
                    ax.scatter(temp[jj], temp[ii], color=cc, s=ss)

        # NOTE(review): the bodies of the next two `if` statements are missing (presumably
        # x- and y-axis labels for the bottom row and the first column).
        if ii == num - 1:
        if jj == 0 and ii != 0:

    return fig
Example #11
# Create a ``matplotlib.cm.ScalarMappable`` (normalization + colormap) mapping a scalar range
# to colors, with optional colormap truncation and linear/log normalization.
# NOTE(review): this copy is incomplete.  The signature is cut off after `args=[0.0, 1.0],`
# (the body reads `cmap`, `scale`, `under`, `over`, `left`, `right`), and the docstring is
# never closed.  Several `else:` lines are missing: in the `scale` inference, the `filter`
# choice, the `rv is None` check, the final branch of the min/max selection, and the
# `norm` choice.  The bodies of the `under` / `over` branches are missing too.
# NOTE(review): in the min/max selection the `np.size(args) == 2` branch is unreachable,
# because `np.size(args) > 1` is tested first.
# NOTE(review): `np.all(args > 0.0)` raises TypeError when `args` is a plain list (such as
# the default); `np.asarray(args)` would be needed.  `np.int` was removed in NumPy 1.24.
# NOTE(review): the locals `min`, `max` and `filter` shadow Python builtins.
def colormap(args=[0.0, 1.0],
    """Create a colormap from a scalar range to a set of colors.

    args : scalar or array_like of scalar
        Range of valid scalar values to normalize with
    cmap : ``matplotlib.colors.Colormap`` object
        Colormap to use.
    scale : str or `None`
        Scaling specification of colormap {'lin', 'log', `None`}.
        If `None`, scaling is inferred based on input `args`.
    under : str or `None`
        Color specification for values below range.
    over : str or `None`
        Color specification for values above range.
    left : float {0.0, 1.0} or `None`
        Truncate the left edge of the colormap to this value.
        If `None`, 0.0 used (if `right` is provided).
    right : float {0.0, 1.0} or `None`
        Truncate the right edge of the colormap to this value
        If `None`, 1.0 used (if `left` is provided).

    smap : ``matplotlib.cm.ScalarMappable``
        Scalar mappable object which contains the members:
        `norm`, `cmap`, and the function `to_rgba`.

    -   Truncation:
        -   If neither `left` nor `right` is given, no truncation is performed.
        -   If only one is given, the other is set to the extreme value: 0.0 or 1.0.


    if cmap is None:
        cmap = 'jet'
    if isinstance(cmap, six.string_types):
        cmap = plt.get_cmap(cmap)

    # Select a truncated subsection of the colormap
    if (left is not None) or (right is not None):
        if left is None:
            left = 0.0
        if right is None:
            right = 1.0
        cmap = cut_colormap(cmap, left, right)

    if under is not None:
    if over is not None:

    if scale is None:
        if np.size(args) > 1 and np.all(args > 0.0):
            scale = 'log'
            scale = 'lin'

    log = _scale_to_log_flag(scale)
    if log:
        filter = 'g'
        filter = None

    # Determine minimum and maximum
    if np.size(args) > 1:
        rv = zmath.minmax(args, filter=filter)
        if rv is None:
            min, max = 0.0, 0.0
            min, max = rv
    elif np.size(args) == 1:
        min, max = 0, np.int(args) - 1
    elif np.size(args) == 2:
        min, max = args
        min, max = 0.0, 0.0

    # Create normalization
    if log:
        norm = mpl.colors.LogNorm(vmin=min, vmax=max)
        norm = mpl.colors.Normalize(vmin=min, vmax=max)

    # Create scalar-mappable
    smap = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    # Bug-Fix something something
    smap._A = []
    # Store type of mapping
    smap.log = log

    return smap
Example #12
# Configure one axis ('x' or 'y') of `ax`:
# - tick colors and label size;
# - grid lines;
# - side and position of the axis, and its limits;
# - inversion and tick-label visibility;
# - scale, label and optional stretching.
# The color is taken from `kwargs` (default 'k'); any other extra keyword raises ValueError.
# Returns the (possibly replaced) axes.
# NOTE(review): this copy is incomplete.  The signature is cut off after `ax,` (the body reads
# `axis`, `trans`, `side`, `fs`, `grid`, `pos`, `lim`, `invert`, `ticks`, `scale`, `thresh`,
# `label`, `stretch` and `kwargs`), and the docstring below has lost its quotes.  `else:`
# lines and the bodies under `if pos is not None:`, `if lim is not None:`, `if invert:` and
# the tick-label loops are missing (flagged below).
# NOTE(review): `ts` is documented but not referenced in the visible code; `ticks` is read
# but undocumented.  `stretch` (when not ~1.0) re-sizes the axes via `stretchAxes`.
def set_axis(ax,
    Configure a particular axis of the given axes object.

       ax     : <matplotlib.axes.Axes>, base axes object to modify
       axis   : <str>, which axis to target {``x`` or ``y``}
       color      : <str>, color for the axis (see ``matplotlib.colors``)
       fs     : <int>, font size for labels
       pos    : <float>, position of axis-label/lines relative to the axes object
       trans  : <str>, transformation type for the axes
       label  : <str>, axes label (``None`` means blank)
       scale  : <str>, axis scale, e.g. 'log', (``None`` means default)
       thresh : <float>, for 'symlog' scaling, the threshold for the linear segment
       side   : <str>, where to place the markings, {``left``, ``right``, ``top``, ``bottom``}
       ts     : <int>, tick-size (for the major ticks only)
       grid   : <bool>, whether grid lines should be enabled
       lim    : <float>[2], limits for the axis range
       invert : <bool>, whether to invert this axis direction (i.e. high to low)
       stretch : <flt>,


    # NOTE(review): `assert` statements are removed under `python -O`, so these argument
    # checks disappear in optimized runs.
    assert axis in ['x', 'y'], "``axis`` must be `x` or `y`!"
    assert trans in ['axes', 'figure'], "``trans`` must be `axes` or `figure`!"
    assert side in VALID_SIDES, "``side`` must be in '%s'" % (VALID_SIDES)

    # The color comes from keyword arguments (popped), defaulting to black
    color = _color_from_kwargs(kwargs, pop=True)
    if color is None:
        color = 'k'

    if len(kwargs) > 0:
        raise ValueError("Additional arguments are not supported!")

    # Set tick colors and font-sizes
    kw = {}
    if fs is not None:
        kw['labelsize'] = fs
    ax.tick_params(axis=axis, which='both', colors=color, **kw)
    #    Set tick-size only for major ticks
    # ax.tick_params(axis=axis, which='major')

    # Set Grid Lines
    # NOTE(review): applied to both axes regardless of `axis` -- confirm this is intended.
    set_grid(ax, grid, axis='both')

    if axis == 'x':
        offt = ax.get_xaxis().get_offset_text()

        # Default side: 'bottom', or chosen from `pos` (NOTE(review): `else:` lines missing)
        if side is None:
            if pos is None:
                side = 'bottom'
                if pos < 0.5:
                    side = 'bottom'
                    side = 'top'

        if pos is not None:

        if lim is not None:
            if np.size(lim) > 2:
                lim = zmath.minmax(lim)

        if invert:
        if not ticks:
            for tlab in ax.xaxis.get_ticklabels():

        # NOTE(review): an `else:` line (the `y`-axis branch) appears to be missing above.
        offt = ax.get_yaxis().get_offset_text()

        # Default side: 'left', or chosen from `pos` (NOTE(review): `else:` lines missing)
        if side is None:
            if pos is None:
                side = 'left'
                if pos < 0.5:
                    side = 'left'
                    side = 'right'

        if pos is not None:


        if lim is not None:

        if invert:
        if not ticks:
            for tlab in ax.yaxis.get_ticklabels():

    # Set Spine colors
    if pos is not None:
        ax.spines[side].set_position((trans, pos))

    # Set Axis Scaling
    if scale is not None:
        _setAxis_scale(ax, axis, scale, thresh=thresh)

    # Set Axis Label
    if label is not None:
        kw = {}
        if fs is not None:
            kw['fs'] = fs
        _setAxis_label(ax, axis, label, color=color, **kw)

    # Optionally stretch the axes along the configured direction
    if not np.isclose(stretch, 1.0):
        if axis == 'x':
            ax = stretchAxes(ax, xs=stretch)
        elif axis == 'y':
            ax = stretchAxes(ax, ys=stretch)

    return ax
Example #13
def plot2DHistProj(xvals, yvals, weights=None, statistic=None, bins=10, filter=None, extrema=None,
                   fig=None, xproj=True, yproj=True, hratio=0.7, wratio=0.7, pad=0.0, alpha=1.0,
                   cmap=None, smap=None, type='hist', scale_to_cbar=True,
                   fs=12, scale='log', histScale='log', labels=None, cbar=True,
                   overlay=None, overlay_fmt=None,
                   left=_LEFT, bottom=_BOTTOM, right=_RIGHT, top=_TOP, lo=None, hi=None,
                   overall=False, overall_bins=20, overall_wide=False, overall_cumulative=False,
                   cumulative=None):
    """Plot a 2D histogram with projections of one or both axes.

    Parameters
    ----------
    xvals : (N,) array_like
        Values corresponding to the x-points of the given data.
    yvals : (N,) array_like
        Values corresponding to the y-points of the given data.
    weights : (N,) array_like or `None`
        Weights used to create histograms.  If `None`, then counts are used.
    statistic : str or `None`
        Type of statistic to be calculated, passed to ``scipy.stats.binned_statistic``,
        e.g. {'count', 'sum', 'mean'}.  If `None`, then 'sum' is used when `weights` are
        provided and 'count' otherwise.
    bins : int or [int, int] or array_like or [array, array]
        Specification for bin sizes.  Integer values are treated as the number of bins to use,
        while arrays are used as the bin edges themselves.  If a pair of values is given, the
        first is used for the x-axis and the second for the y-axis.
    filter : str or `None`, or sequence of up to three str or `None`
        String(s) specifying how to filter the input data relative to zero (passed to
        ``zmath.comparison_filter``), applied to `xvals`, `yvals` and `weights` respectively.
        A single value (including a single string) applies to all three; missing trailing
        entries mean "no filter".  The third entry is also used when computing `extrema`.
        If `None` and `histScale` is logarithmic, 'g' is used.
    extrema : (2,) array_like or `None`
        Limits of the color scale.  If `None`, they are calculated from the histograms
        (see `lo` and `hi`).
    fig : ``matplotlib.figure.Figure``
        Figure instance to which axes are added for plotting.  One is created if not given.
    xproj : bool
        Whether to also plot the projection of the x-axis (i.e. histogram ignoring y-values).
    yproj : bool
        Whether to also plot the projection of the y-axis (i.e. histogram ignoring x-values).
    hratio : float
        Fraction of the total available height-space to use for the primary axes (2D hist).
    wratio : float
        Fraction of the total available width-space to use for the primary axes (2D hist).
    pad : float
        Padding between central axis and the projected ones.
    alpha : float
        Transparency of the points in a scatter plot (``type='scatter'`` only).
    cmap : ``matplotlib.colors.Colormap`` object
        Matplotlib colormap to use for coloring histogram.  Overridden if `smap` is provided.
    smap : `matplotlib.cm.ScalarMappable` object or `None`
        A scalar-mappable object to use for colormaps, or `None` for one to be created.
    type : str, {'hist', 'scatter'}
        What type of plot should be in the center, a 2D Histogram or a scatter-plot.
    scale_to_cbar : bool
        If True, set the counts-axis limits of the projection plots to the color `extrema`.
    fs : int
        Font size used for tick labels and axis/colorbar labels.
    scale : str or [str, str]
        Specification for the axes scaling {'log','lin'}.  If two values are given, the first
        is used for the x-axis and the second for the y-axis.
    histScale : str
        Scaling to use for the histograms {'log','lin'} -- the color scale on the 2D histogram,
        and the counts axis of the 1D histograms.
    labels : (2,) or (3,) sequence of str, or `None`
        Labels for the x-axis, y-axis and colorbar; a missing third label defaults to ''.
        For scatter plots the third label is put on the colorbar, and 'Count' is passed to
        the figure construction in its place.  The given sequence is never modified.
    cbar : bool
        Add a colorbar.
    overlay : str or `None`, if str {'counts', 'values'}
        Print a str on each bin writing,
        'counts' - the number of values included in that bin, or
        'values' - the value of the bin itself.
    overlay_fmt : str or `None`
        Format specification on overlayed values, e.g. "02d" (no colon or brackets).
    left : float {0.0, 1.0}
        Location of the left edge of axes relative to the figure.
    bottom : float {0.0, 1.0}
        Location of the bottom edge of axes relative to the figure.
    right : float {0.0, 1.0}
        Location of the right edge of axes relative to the figure.
    top : float {0.0, 1.0}
        Location of the top edge of axes relative to the figure.
    lo : scalar or `None`
        When autocalculating `extrema`, ignore histogram entries below this value.
    hi : scalar or `None`
        When autocalculating `extrema`, ignore histogram entries above this value.
    overall : bool
        If True, also plot an overall histogram of `weights` on an extra axis, with each bar
        colored by the colormap value at its bin center.
    overall_bins : int
        Passed as ``num`` to ``zmath.spacing`` to construct the overall-histogram bin edges.
    overall_wide : bool
        Passed to ``_constructFigure``; presumably controls the size of the overall axis.
    overall_cumulative : bool
        If True, overplot the cumulative sum of the overall histogram.
    cumulative : `None` or (2,) array_like
        If not `None`, the histograms are replaced by cumulative statistics: the full value
        is passed to ``_cumulative_stat2d`` for the 2D histogram, and elements [0] and [1]
        to ``_cumulative_stat1d`` for the x- and y-projections respectively.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure object containing plots.

    Raises
    ------
    ValueError
        If input shapes disagree, or `overlay`, `scale`, `labels`, `type` or `bins` are
        invalid, or the constructed bin edges are non-finite or degenerate.
    """
    # Work with arrays so that index-based filtering below also works for list inputs
    xvals = np.asarray(xvals)
    yvals = np.asarray(yvals)
    if weights is not None:
        weights = np.asarray(weights)

    # Make sure shapes of input arrays are valid
    if np.shape(xvals) != np.shape(yvals):
        raise ValueError("Shape of `xvals` ({}) must match `yvals` ({}).".format(
            np.shape(xvals), np.shape(yvals)))
    if weights is not None and np.shape(weights) != np.shape(xvals):
        raise ValueError("Shape of `weights` ({}) must match `xvals` and `yvals` ({}).".format(
            np.shape(weights), np.shape(xvals)))

    if overlay is not None and not overlay.startswith(('val', 'count')):
        # Braces are doubled so that ``str.format`` emits them literally
        raise ValueError("`overlay` = '{}', must be {{'values', 'count'}}".format(overlay))

    # Make sure the given `scale` is valid
    if np.size(scale) == 1:
        scale = [scale, scale]
    elif np.size(scale) != 2:
        raise ValueError("`scale` must be one or two scaling specifications!")

    # Check the `labels`
    if labels is None:
        labels = ['', '', '']
    elif np.size(labels) == 2:
        labels = [labels[0], labels[1], '']

    if np.size(labels) != 3:
        raise ValueError("`labels` = '{}' is invalid.".format(labels))
    # Copy, so the caller's sequence is never mutated (and tuples are supported)
    labels = list(labels)

    # Make sure scale strings are matplotlib-compliant
    scale = [plot_core._clean_scale(sc) for sc in scale]

    # Determine type of central plot
    if type.startswith('hist'):
        type_hist = True
    elif type.startswith('scat'):
        type_hist = False
        # The third label goes on the colorbar; the projections show counts instead
        cblabel = str(labels[2])
        labels[2] = 'Count'
    else:
        raise ValueError("`type` = '{}', must be either 'hist', or 'scatter'.".format(type))

    # Infer default statistic
    if statistic is None:
        statistic = 'count' if weights is None else 'sum'

    if filter is None and histScale.startswith('log'):
        filter = 'g'

    # Normalize `filter` into a 3-element list: [xvals-filter, yvals-filter, weights-filter].
    # Strings are iterable, so they must be recognized explicitly as single specifications.
    if filter is None:
        filter = [None, None, None]
    elif isinstance(filter, str) or not np.iterable(filter):
        filter = 3 * [filter]
    elif len(filter) == 1:
        filter = 3 * [filter[0]]
    else:
        filter = list(filter) + max(3 - len(filter), 0) * [None]

    # Filter input data, keeping `xvals`, `yvals` and `weights` aligned.  The tuple is rebuilt
    # on every pass so each filter sees the data already reduced by the previous ones.
    for idx in range(3):
        data = (xvals, yvals, weights)[idx]
        if filter[idx] is None or data is None:
            continue
        inds = zmath.comparison_filter(data, filter[idx], inds=True)
        xvals = xvals[inds]
        yvals = yvals[inds]
        if weights is not None:
            weights = weights[inds]

    # Create and initialize figure and axes
    fig, prax, xpax, ypax, cbax, ovax = _constructFigure(
        fig, xproj, yproj, overall, overall_wide, hratio, wratio, pad,
        scale, histScale, labels, cbar,
        left, bottom, right, top, fs=fs)

    # Create bins
    # -----------
    if np.isscalar(bins):
        # A single scalar value -- apply to both
        xbins = bins
        ybins = bins
    elif len(bins) == 2:
        # A pair of bin specifications -- separate and apply
        xbins = bins[0]
        ybins = bins[1]
    elif len(bins) > 2:
        # A single array of edges -- apply to both
        xbins = bins
        ybins = bins
    else:
        raise ValueError("Unrecognized shape of ``bins`` = %s" % (str(np.shape(bins))))

    # If a number of bins is given, create an appropriate spacing (N bins -> N+1 edges)
    if np.ndim(xbins) == 0:
        xbins = zmath.spacing(xvals, num=xbins+1, scale=scale[0])

    if np.ndim(ybins) == 0:
        ybins = zmath.spacing(yvals, num=ybins+1, scale=scale[1])

    # Make sure bins look okay
    for arr, name in zip([xbins, ybins], ['xbins', 'ybins']):
        delta = np.diff(arr)
        if np.any(~np.isfinite(delta) | (delta == 0.0)):
            raise ValueError("Error constructing `{}` = {}, delta = {}".format(name, arr, delta))

    # Calculate Histograms
    # --------------------
    # Per-dimension bin numbers (``expand_binnumbers``) are only needed for cumulative stats
    expand = cumulative is not None
    #    2D
    hist_2d, xedges_2d, yedges_2d, binnums_2d = sp.stats.binned_statistic_2d(
        xvals, yvals, weights, statistic=statistic, bins=[xbins, ybins],
        expand_binnumbers=expand)
    hist_2d = np.nan_to_num(hist_2d)
    #    X-projection (ignore Y)
    hist_xp, edges_xp, bins_xp = sp.stats.binned_statistic(
        xvals, weights, statistic=statistic, bins=xbins)
    #    Y-projection (ignore X)
    hist_yp, edges_yp, bins_yp = sp.stats.binned_statistic(
        yvals, weights, statistic=statistic, bins=ybins)

    if cumulative is not None:
        hist_2d = _cumulative_stat2d(
            weights, hist_2d.shape, binnums_2d, statistic, cumulative)
        hist_xp = _cumulative_stat1d(
            weights, hist_xp.size, bins_xp, statistic, cumulative[0])
        hist_yp = _cumulative_stat1d(
            weights, hist_yp.size, bins_yp, statistic, cumulative[1])

    # Calculate Extrema - Preserve input extrema if given, otherwise calculate
    extrema = _set_extrema(extrema, [hist_2d, hist_xp, hist_yp], filter=filter[2], lo=lo, hi=hi)
    # Create scalar-mappable if needed
    if smap is None:
        smap = plot_core.colormap(extrema, cmap=cmap, scale=histScale)

    # Plot Histograms and Projections
    # -------------------------------
    if type_hist:
        # 2D Histogram
        overlay_values = None
        # Optionally overlay a string on each bin: either the bin values or the counts
        if overlay is not None:
            if overlay.startswith('val'):
                overlay_values = hist_2d
                if overlay_fmt is None:
                    overlay_fmt = ''
            else:
                if overlay_fmt is None:
                    overlay_fmt = 'd'
                overlay_values, _, _, binnums = sp.stats.binned_statistic_2d(
                    xvals, yvals, weights, statistic='count', bins=[xbins, ybins],
                    expand_binnumbers=expand)

                if cumulative is not None:
                    overlay_values = _cumulative_stat2d(
                        np.ones_like(xvals), overlay_values.shape, binnums, 'count', cumulative)

                overlay_values = overlay_values.astype(int)

        pcm, smap, cbar, cs = plot2DHist(prax, xedges_2d, yedges_2d, hist_2d, cscale=histScale,
                                         cbax=cbax, labels=labels, cmap=cmap, smap=smap,
                                         extrema=extrema, fs=fs, scale=scale,
                                         overlay=overlay_values, overlay_fmt=overlay_fmt)

        def _projection_colors(hist):
            # Map projected-histogram values onto the shared colormap; for log color scales,
            # raise non-positive values to the smallest positive one so they stay mappable.
            vals = np.array(hist)
            if smap.log:
                tmin, _ = zmath.minmax(vals, filter='g')
                vals = np.maximum(vals, tmin)
            return smap.to_rgba(vals)

        if xpax:
            colors_xp = _projection_colors(hist_xp)
        if ypax:
            colors_yp = _projection_colors(hist_yp)

    else:
        # Scatter Plot
        colors = smap.to_rgba(weights)
        prax.scatter(xvals, yvals, c=colors, alpha=alpha)

        if cbar:
            cbar = plt.colorbar(smap, cax=cbax)
            cbar.set_label(cblabel, fontsize=fs)

        # Make projection colors all grey
        colors_xp = '0.8'
        colors_yp = '0.8'

    hist_log = plot_core._scale_to_log_flag(histScale)

    # Plot projection of the x-axis (i.e. ignore 'y').
    # Bars start at the left bin edges, so ``align='edge'`` is required; ``log`` sets the
    # counts (y) axis, which follows `histScale` just like the y-projection below.
    if xpax:
        xpax.bar(edges_xp[:-1], hist_xp, width=np.diff(edges_xp), align='edge',
                 color=colors_xp, log=hist_log)
        plt.setp(xpax.get_yticklabels(), fontsize=fs)
        # set bounds to bin edges
        plot_core.set_lim(xpax, 'x', data=xedges_2d)
        # Set counts-axis limits to match those of colorbar
        if scale_to_cbar:
            xpax.set_ylim(zmath.minmax(extrema, round=0))

    # Plot projection of the y-axis (i.e. ignore 'x'); ``log`` here sets the counts (x) axis
    if ypax:
        ypax.barh(edges_yp[:-1], hist_yp, height=np.diff(edges_yp), align='edge',
                  color=colors_yp, log=hist_log)
        plt.setp(ypax.get_yticklabels(), fontsize=fs, rotation=90)
        # set bounds to bin edges
        plot_core.set_lim(ypax, 'y', data=yedges_2d)
        try:
            ypax.locator_params(axis='x', tight=True, nbins=4)
        except (TypeError, AttributeError):
            # Some tick locators (e.g. on log axes) may not accept ``nbins``
            ypax.locator_params(axis='x', tight=True)

        # Set counts-axis limits to match those of colorbar
        if scale_to_cbar:
            ypax.set_xlim(zmath.minmax(extrema, round=0))

    # Plot Overall Histogram of values
    if overall:
        ov_bins = zmath.spacing(weights, num=overall_bins)
        bin_centers = zmath.midpoints(ov_bins, log=hist_log)
        nums, _, patches = ovax.hist(weights, ov_bins, log=hist_log, facecolor='0.5', edgecolor='k')
        # Color each bar by the colormap value at its bin center
        for patch, cent in zip(patches, bin_centers):
            patch.set_facecolor(smap.to_rgba(cent))

        # Add cumulative distribution
        if overall_cumulative:
            ovax.plot(bin_centers, np.cumsum(nums), 'k--')

    prax.set(xlim=zmath.minmax(xedges_2d), ylim=zmath.minmax(yedges_2d))

    return fig
def _mergeUnique(snaps, old_ids, old_scales, new_data, log):
    """Merge one snapshot's IDs and scales into the accumulated unique-ID arrays.

    The merge walks `old_ids` and the new IDs in a single forward pass, so both are
    presumably sorted in ascending order -- NOTE(review): confirm with callers.

    Parameters
    ----------
    snaps : scalar or list of lists
        For each entry of `old_ids`, the list of snapshot numbers it has appeared in.
        If a scalar is given, it is expanded to a separate ``[snaps]`` list for every
        existing ID.  A list argument is modified in place (appends/inserts).
    old_ids : (N,) ndarray of ``np.uint64``
        IDs accumulated so far.
    old_scales : ndarray whose first axis has length N
        Scale extrema for each ID.  A matched ID's row is replaced by ``zmath.minmax`` of the
        old row and the new scale.  Assumed to be (N, 2) [first, last] -- verify.  The
        caller's array may be modified in place.
    new_data : mapping
        Indexed with ``DETAILS.SNAP`` (snapshot number), ``DETAILS.IDS`` (``np.uint64``
        IDs) and ``DETAILS.SCALES`` (one scale per ID).
    log : ``logging.Logger``-like
        Receives a debug message summarizing the merge.

    Returns
    -------
    snaps : list of lists
        Updated snapshot lists, one per unique ID.
    old_ids : ndarray
        Updated unique IDs.
    old_scales : ndarray
        Updated scale extrema, one row per unique ID.

    Raises
    ------
    RuntimeError
        If either ID array is not of type ``np.uint64``, or if the merged arrays do not
        match the number of unique IDs.
    """
    new_snap = new_data[DETAILS.SNAP]
    new_ids = new_data[DETAILS.IDS]
    new_scales = new_data[DETAILS.SCALES]

    n_old = old_ids.size
    n_new = new_ids.size
    log.debug(" - %d so far, Snap %d with %d entries", n_old, new_snap, n_new)

    # Validate ID types before anything is modified (``np.insert`` preserves dtype, so
    # checking up-front is equivalent to checking after the merge)
    if old_ids.dtype.type is not np.uint64:
        print(("types = ", old_ids.dtype.type, new_ids.dtype.type))
        raise RuntimeError("old_ids at snap = %d is non-integral!" % (new_snap))

    if new_ids.dtype.type is not np.uint64:
        print(("types = ", old_ids.dtype.type, new_ids.dtype.type))
        raise RuntimeError("new_ids at snap = %d is non-integral!" % (new_snap))

    # Expand a scalar snapshot number into one separate list per existing ID
    if np.isscalar(snaps):
        snaps = [[snaps] for _ in range(n_old)]

    oo = 0
    for nn, ss in zip(new_ids, new_scales):
        # Advance through the old IDs to the first one >= this new ID (or the last one)
        while oo < n_old-1 and old_ids[oo] < nn:
            oo += 1

        if n_old > 0 and old_ids[oo] == nn:
            # Existing ID: add this snapshot to its list and widen its scale extrema.
            # ``hstack`` flattens the old row and the new scale(s) into one array.
            snaps[oo].append(new_snap)
            old_scales[oo] = zmath.minmax(np.hstack([old_scales[oo], ss]))
        else:
            # New ID: insert before position `oo`, unless it belongs after the last element
            ins = oo
            if n_old > 0 and oo == n_old-1 and nn > old_ids[oo]:
                ins = oo+1
            old_ids = np.insert(old_ids, ins, nn, axis=0)
            old_scales = np.insert(old_scales, ins, ss, axis=0)
            snaps.insert(ins, [new_snap])
            n_old += 1

    # Make sure things seem right: every array must hold exactly one entry per unique ID
    test_ids = np.unique(np.hstack([old_ids, new_ids]))
    n_test = test_ids.size
    if len(old_ids) != n_test or n_test != len(old_scales) or n_test != len(snaps):
        dups = Counter(old_ids) - Counter(test_ids)
        print("Duplicates = %s" % (str(list(dups.keys()))))
        errStr = "ERROR: num unique should be %d" % (n_test)
        errStr += "\nBut len(old_ids) = %d" % (len(old_ids))
        errStr += "\nBut len(old_scales) = %d" % (len(old_scales))
        errStr += "\nBut len(snaps) = %d" % (len(snaps))
        raise RuntimeError(errStr)

    return snaps, old_ids, old_scales