Exemple #1
0
def getZfromOrigins(origins, star_pars):
    """
    Build the "true" membership array Z from synthetic origins.

    Stars are assumed to be ordered so that the first ``origins[0].nstars``
    stars belong to the first origin, the next ``origins[1].nstars`` to the
    second origin, and so on. Any stars left over once every origin has been
    accounted for are assigned to an extra, final background column.

    Parameters
    ----------
    origins: sequence of objects with an integer ``nstars`` attribute, or a
        string filename passed to ``dt.loadGroups``
    star_pars: dict with key 'xyzuvw' (array whose first axis indexes
        stars), or a string filename passed to ``dt.loadXYZUVW``

    Returns
    -------
    z: [nstars, ngroups] array of 0s and 1s, with one extra final column for
        the background when there are stars not claimed by any origin

    Raises
    ------
    ValueError
        If the origins claim more stars than ``star_pars`` contains.
    """
    if isinstance(origins, str):
        origins = dt.loadGroups(origins)
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    nstars = star_pars['xyzuvw'].shape[0]
    ngroups = len(origins)
    nassoc_stars = int(sum(o.nstars for o in origins))
    if nassoc_stars > nstars:
        # Otherwise slices would silently truncate and a spurious
        # background column would be added.
        raise ValueError(
            "Origins account for {} stars but only {} stars were "
            "provided".format(nassoc_stars, nstars))
    # stars not claimed by any origin are treated as background
    using_bg = nstars > nassoc_stars
    z = np.zeros((nstars, ngroups + int(using_bg)))
    stars_so_far = 0
    # set association members' memberships to 1
    for i, o in enumerate(origins):
        z[stars_so_far:stars_so_far + o.nstars, i] = 1.
        stars_so_far += o.nstars
    # set remaining stars as members of background
    if using_bg:
        z[stars_so_far:, -1] = 1.
    return z
    print("Checking {}".format(plt_file))
    if not os.path.isfile(plt_file):
        print("Plotting {}".format(plt_file))
        try:
            star_pars_file = rdir + 'xyzuvw_now.fits'
            chain_file = rdir + 'final_chain.npy'
            origins_file = rdir + 'origins.npy'
            lnprob_file = rdir + 'final_lnprob.npy'

            chain = np.load(chain_file).reshape(-1, 9)
            lnprob = np.load(lnprob_file)
            # argmax must run over the loaded lnprob values, not the filename
            # string (argmax of a string is always 0, i.e. the first sample).
            # Assumes lnprob flattens in the same order as `chain` reshaped to
            # (-1, 9) -- TODO confirm against the sampler's output layout.
            best_pars = chain[np.argmax(lnprob)]
            best_fit = chronostar.component.Component(best_pars, internal=True)
            origins = dt.loadGroups(origins_file)

            star_pars = dt.loadXYZUVW(star_pars_file)

            fp.plotMultiPane(
                ['xy', 'xz', 'uv', 'xu', 'yv', 'zw'],
                star_pars,
                [best_fit],
                origins=origins,
                save_file=rdir +
                'multi_plot_{}_{}_{}_{}_{}_{}.pdf'.format(*scenario),
                title='{}Myr, {}pc, {}km/s, {} stars, {}, {}'.format(
                    *scenario),
            )
            print("done")
        except Exception as e:
            # Best-effort: results for this scenario may not exist yet.
            # Report the reason so genuine errors are not hidden.
            print("Not ready yet... ({})".format(e))
)

# Initialize the MPI-based pool used for parallelization.
# If the pool cannot be created, fall back to serial execution (pool=None).
using_mpi = True
try:
    pool = MPIPool()
except Exception as e:
    # `Exception` rather than a bare except, so Ctrl-C / sys.exit still work
    logging.info("MPI doesn't seem to be installed... maybe install it?")
    logging.info("MPIPool initialisation failed with: {}".format(e))
    using_mpi = False
    pool = None
else:
    logging.info("Successfully initialised mpi pool")

star_pars = dt.loadXYZUVW(xyzuvw_file)

np.set_printoptions(suppress=True)

# --------------------------------------------------------------------------
# Get grid-based Z membership
# --------------------------------------------------------------------------

# Grid: span the data range padded by 1 unit on each side, so every star
# lies strictly inside the outermost grid edges. Columns 3-5 are U, V, W.
xyzuvw = star_pars['xyzuvw']
dmin = np.min(xyzuvw, axis=0) - 1
dmax = np.max(xyzuvw, axis=0) + 1
# (the trailing list looks like an earlier hard-coded U grid)
grid_u = np.linspace(dmin[3], dmax[3], 10)  # [-10000, -60, -20, 20, 60, 10000]
grid_v = np.linspace(dmin[4], dmax[4], 10)
grid_w = np.linspace(dmin[5], dmax[5], 6)
Exemple #4
0
    using_mpi = False
    pool=None

if using_mpi:
    if not pool.is_master():
        print("One thread is going to sleep")
        # Wait for instructions from the master process.
        pool.wait()
        sys.exit(0)
print("Only one thread is master")

print("Master should be working in the directory:\n{}".format(rdir))

star_pars = dt.loadXYZUVW(xyzuvw_file)

# GET BACKGROUND LOG OVERLAP DENSITIES:
# Reuse the cached overlaps in `bg_ln_ols_file` when they exist and have one
# entry per star; otherwise (re)compute them.
logging.info("Acquiring background overlaps")
bg_ln_ols = None
try:
    bg_ln_ols = np.load(bg_ln_ols_file)
except IOError:
    logging.info("Could not load {}".format(bg_ln_ols_file))
else:
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let a stale cache of the wrong length through.
    if len(bg_ln_ols) == len(star_pars['xyzuvw']):
        logging.info("Loaded bg_ln_ols from file")
    else:
        logging.info("Stored bg_ln_ols length does not match number of "
                     "stars; recalculating")
        bg_ln_ols = None

if bg_ln_ols is None:
    logging.info("Calculating background overlaps")
    logging.info(" -- this step employs scipy's kernel density estimator")
    logging.info(" -- so could take a few minutes...")
    bg_ln_ols = dt.getKernelDensities(gaia_xyzuvw_file, star_pars['xyzuvw'],
                                      amp_scale=0.001)
Exemple #5
0
#
#
# # make sure stars are initialised as expected
# if can_plot:
#     for dim1, dim2 in ('xy', 'xu', 'yv', 'zw', 'uv'):
#         plt.clf()
#         true_memb = dt.getZfromOrigins(origins, star_pars)
#         fp.plotPaneWithHists(dim1,dim2,groups=origins,
#                              weights=[origin.nstars for origin in origins],
#                              star_pars=star_pars,
#                              group_now=True,
#                              membership=true_memb,
#                              true_memb=true_memb)
#         plt.savefig(rdir + 'pre_plot_{}{}.pdf'.format(dim1, dim2))

# Load the stellar data the fit runs on (dict returned by dt.loadXYZUVW;
# presumably holding 'xyzuvw' means and covariances -- TODO confirm)
star_pars = dt.loadXYZUVW(xyzuvw_conv_savefile)

# Component-count bookkeeping: MAX_COMP looks like an upper limit on the
# number of components to try, starting from `ncomps` (usage not visible
# here -- confirm against the rest of the script)
MAX_COMP = 6
ncomps = 1

# Set up initial values of results
# Log-probabilities start at -inf and the BIC at +inf (the worst possible
# scores), presumably so the first fit always compares as an improvement.
prev_groups = None
prev_meds = None
prev_lnpost = -np.inf
prev_BIC = np.inf
prev_lnlike = -np.inf
prev_z = None
#
# # Initialise z
# nstars = star_pars['xyzuvw'].shape[0]
# nassoc_stars = np.sum([o.nstars for o in origins])
"""
from __future__ import print_function, division

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sys
sys.path.insert(0, '..')
import chronostar.retired2.datatool as dt

# Padding, in units of each dimension's data span, added on both sides of
# the bounding box
MARGIN = 2.0

assoc_name = "bpmg_cand_w_gaia_dr2_astrometry_comb_binars"
xyzuvw_file = "../data/{}_xyzuvw.fits".format(assoc_name)
xyzuvw_dict = dt.loadXYZUVW(xyzuvw_file)

# construct box
# Per-dimension min/max of the association's star means, widened by
# MARGIN * span on each side.
kin_max = np.max(xyzuvw_dict['xyzuvw'], axis=0)
kin_min = np.min(xyzuvw_dict['xyzuvw'], axis=0)
span = kin_max - kin_min
upper_boundary = kin_max + MARGIN*span
lower_boundary = kin_min - MARGIN*span


# get gaia stars within box
# (assumes the .npy holds an [nstars, 6] array in the same xyzuvw ordering as
# xyzuvw_dict['xyzuvw'] -- TODO confirm)
gaia_xyzuvw_file = "../data/gaia_dr2_mean_xyzuvw.npy"
gaia_xyzuvw = np.load(gaia_xyzuvw_file)
mask = np.where(
    np.all(
        (gaia_xyzuvw < upper_boundary) & (gaia_xyzuvw > lower_boundary),
Exemple #7
0
def plotMultiPane(dim_pairs, star_pars, groups, origins=None,
                  save_file='dummy.pdf', title=None):
    """
    Flexible function that plots many 2D slices through data and fits

    Takes as input a list of dimension pairs, stellar data and fitted
    groups, and will plot each dimension pair in a different pane. Panes
    are laid out on a grid that is never taller than it is wide.

    TODO: Maybe add functionality to control the data plotted in each pane

    Parameters
    ----------
    dim_pairs: a non-empty list of dimension pairs e.g.
        [(0,1), (3,4), (0,3), (1,4), (2,5)]
        ['xz', 'uv', 'zw']
        ['XY', 'UV', 'XU', 'YV', 'ZW']
    star_pars: either
        dictionary of stellar data with keys 'xyzuvw' and 'xyzuvw_cov'
            or
        string filename to saved data
    groups: either
        a single synthesiser.Group object,
        a list or array of synthesiser.Group objects,
            or
        string filename to data saved as '.npy' file
    origins: (optional) a single group object or a sequence of them, passed
        to every pane
    save_file: string, name (and path) of saved plot figure (pdf format);
        a falsy value skips saving
    title: (optional) string used as the figure's super-title

    Returns
    -------
    (nothing)

    Raises
    ------
    ValueError
        If `dim_pairs` is empty.
    """

    # Tidying up inputs
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(groups, str):
        # Group objects are pickled inside the .npy file, so allow_pickle is
        # required; only load trusted files.
        groups = np.load(groups, allow_pickle=True)
        # handle case where groups is a single stored object
        if len(groups.shape) == 0:
            groups = groups.item()
    # ensure groups is iterable
    try:
        len(groups)
    except TypeError:  # groups is a single group instance
        groups = [groups]

    # `is not None` rather than truthiness: truth-testing a numpy array of
    # several groups raises ValueError
    if origins is not None:
        try:
            len(origins)
        except TypeError:  # origins is a single group instance
            origins = [origins]

    # setting up plot dimensions
    npanes = len(dim_pairs)
    if npanes == 0:
        raise ValueError("dim_pairs must contain at least one dimension pair")
    rows = int(np.sqrt(npanes))  # plots are never taller than wide
    cols = (npanes + rows - 1) // rows  # get enough cols
    ax_h = 5
    ax_w = 5
    # squeeze=False always yields a 2D array of axes (even for one pane),
    # so flattening below is always valid
    f, axs = plt.subplots(rows, cols, squeeze=False)
    f.set_size_inches(ax_w * cols, ax_h * rows)
    flat_axs = axs.flatten()

    # drawing each axes
    for i, (dim1, dim2) in enumerate(dim_pairs):
        plotPane(dim1, dim2, flat_axs[i], groups=groups, origins=origins,
                 star_pars=star_pars, star_orbits=False,
                 group_then=True, group_now=True, group_orbit=True,
                 annotate=False)

    if title:
        f.suptitle(title)
    if save_file:
        f.savefig(save_file, format='pdf')
#
#
# # make sure stars are initialised as expected
# if can_plot:
#     for dim1, dim2 in ('xy', 'xu', 'yv', 'zw', 'uv'):
#         plt.clf()
#         true_memb = dt.getZfromOrigins(origins, star_pars)
#         fp.plotPaneWithHists(dim1,dim2,groups=origins,
#                              weights=[origin.nstars for origin in origins],
#                              star_pars=star_pars,
#                              group_now=True,
#                              membership=true_memb,
#                              true_memb=true_memb)
#         plt.savefig(rdir + 'pre_plot_{}{}.pdf'.format(dim1, dim2))

# Load the stellar data the fit runs on (dict returned by dt.loadXYZUVW;
# presumably holding 'xyzuvw' means and covariances -- TODO confirm)
star_pars = dt.loadXYZUVW(xyzuvw_conv_savefile)

# Component-count bookkeeping: MAX_COMP looks like an upper limit on the
# number of components to try, starting from `ncomps` (usage not visible
# here -- confirm against the rest of the script)
MAX_COMP = 6
ncomps = 1

# Set up initial values of results
# Log-probabilities start at -inf and the BIC at +inf (the worst possible
# scores), presumably so the first fit always compares as an improvement.
prev_groups = None
prev_meds = None
prev_lnpost = -np.inf
prev_BIC = np.inf
prev_lnlike = -np.inf
prev_z = None
#
# # Initialise z
# nstars = star_pars['xyzuvw'].shape[0]
# nassoc_stars = np.sum([o.nstars for o in origins])
Exemple #9
0
def plotPaneWithHists(dim1,
                      dim2,
                      fignum=None,
                      groups=None,
                      weights=None,
                      star_pars=None,
                      star_orbits=False,
                      group_then=False,
                      group_now=False,
                      group_orbit=False,
                      annotate=False,
                      bg_hists=None,
                      membership=None,
                      true_memb=None,
                      savefile='',
                      with_bg=False,
                      range_1=None,
                      range_2=None,
                      residual=False,
                      markers=None,
                      group_bg=False,
                      isotropic=False,
                      color_labels=None,
                      marker_labels=None,
                      marker_order=None,
                      ordering=None,
                      no_bg_covs=False):
    """
    Plot a 2D projection of data and fit along with flanking 1D projections.

    Uses global constants COLORS and HATCHES to inform consistent colour
    scheme.
    Can use this to plot different panes of one whole figure

    TODO: Incorporate Z
    TODO: incorporate background histogram

    Parameters
    ----------
    dim1: x-axis, can either be integer 0-5 (inclusive) or a letter from
          'xyzuvw' (either case)
    dim2: y-axis, same conditions as dim1
    fignum: figure number in which to create the plot
    groups: a list of (or just one) synthesiser.Group objects, or a filename
            of a '.npy' file storing them, corresponding to the fit of the
            origin(s). Defaults to no groups.
    weights: (optional) relative weight of each group in the 1D projections.
             If omitted: the number of stars when there is a single group and
             no background, else the column sums of `membership` when given,
             else uniform.
    star_pars:  dict object with keys 'xyzuvw' ([nstars,6] array of current
                star means) and 'xyzuvw_cov' ([nstars,6,6] array of current
                star covariance matrices), or a filename of saved data
    star_orbits: (bool) plot the calculated stellar traceback orbits of
                        central estimate of measurements
    group_then: (bool) plot the group's origin
    group_now:  (bool) plot the group's current day distribution
    group_orbit: (bool) plot the trajectory of the group's mean
    annotate: (bool) add text describing the figure's contents
    bg_hists: (optional) background histograms (or .npy filename) passed to
              plot1DProjection
    membership: (optional) membership array (or .npy filename) with one
                column per component
    true_memb: (optional) "true" membership array, passed to plotPane
    savefile: if non-empty, the figure is saved to this path
    with_bg: (bool) treat the final column of Z as background memberships
             and color accordingly
    range_1, range_2: (optional) axis limits for dim1 and dim2
    residual: passed to plot1DProjection
    markers, group_bg, isotropic, color_labels, marker_labels, ordering,
    no_bg_covs: passed through to plotPane
    marker_order: currently unused; kept for backward compatibility

    Returns
    -------
    xlim, ylim: the x and y limits of the central pane
    """
    labels = 'XYZUVW'
    axes_units = 3 * ['pc'] + 3 * ['km/s']

    # List-like arguments default to None (not a shared mutable [] default)
    # and are normalised to fresh empty lists here.
    if groups is None:
        groups = []
    if color_labels is None:
        color_labels = []
    if marker_labels is None:
        marker_labels = []

    if isinstance(membership, str):
        membership = np.load(membership)
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    # groups must be resolved before the default weights below use len(groups)
    if isinstance(groups, str):
        # Group objects are pickled inside the .npy file, so allow_pickle is
        # required; only load trusted files.
        groups = np.load(groups, allow_pickle=True)
        if len(groups.shape) == 0:
            groups = groups.item()
    # ensure groups is iterable (a single group instance may be passed)
    try:
        len(groups)
    except TypeError:
        groups = [groups]
    if isinstance(bg_hists, str):
        bg_hists = np.load(bg_hists)

    if not isinstance(dim1, (int, np.integer)):
        dim1 = labels.index(dim1.upper())
    if not isinstance(dim2, (int, np.integer)):
        dim2 = labels.index(dim2.upper())

    # TODO: clarify what exactly you're trying to do here
    if weights is None and len(groups) > 0:
        if len(groups) == 1 and not with_bg:
            weights = np.array([len(star_pars['xyzuvw'])])
        elif membership is not None:
            weights = membership.sum(axis=0)
        else:
            weights = np.ones(len(groups)) / len(groups)

    tick_params = {'direction': 'in', 'top': True, 'right': True}

    # Set up plot. Every pane below is added to `fig` explicitly, so the
    # requested figure number and size are the ones actually drawn into.
    fig_width = 5  # inch
    fig_height = 5  # inch
    fig = plt.figure(fignum, figsize=(fig_width, fig_height))
    fig.clf()
    gs = gridspec.GridSpec(4, 4)

    # Plot central pane
    axcen = fig.add_subplot(gs[1:, :-1])
    xlim, ylim = plotPane(dim1,
                          dim2,
                          ax=axcen,
                          groups=groups,
                          star_pars=star_pars,
                          star_orbits=star_orbits,
                          group_then=group_then,
                          group_now=group_now,
                          group_orbit=group_orbit,
                          annotate=annotate,
                          membership=membership,
                          true_memb=true_memb,
                          with_bg=with_bg,
                          markers=markers,
                          group_bg=group_bg,
                          isotropic=isotropic,
                          range_1=range_1,
                          range_2=range_2,
                          marker_labels=marker_labels,
                          color_labels=color_labels,
                          ordering=ordering,
                          no_bg_covs=no_bg_covs)
    axcen.tick_params(**tick_params)

    # Plot flanking 1D projections, sharing the central pane's limits
    axtop = fig.add_subplot(gs[0, :-1])
    axtop.set_xlim(xlim)
    axtop.set_xticklabels([])
    plot1DProjection(dim1,
                     star_pars,
                     groups,
                     weights,
                     ax=axtop,
                     bg_hists=bg_hists,
                     with_bg=with_bg,
                     membership=membership,
                     residual=residual,
                     x_range=xlim)
    axtop.set_ylabel('Stars per {}'.format(axes_units[dim1]))
    axtop.tick_params(**tick_params)

    axright = fig.add_subplot(gs[1:, -1])
    axright.set_ylim(ylim)
    axright.set_yticklabels([])
    plot1DProjection(dim2,
                     star_pars,
                     groups,
                     weights,
                     ax=axright,
                     bg_hists=bg_hists,
                     horizontal=True,
                     with_bg=with_bg,
                     membership=membership,
                     residual=residual,
                     x_range=ylim)
    axright.set_xlabel('Stars per {}'.format(axes_units[dim2]))
    axright.tick_params(**tick_params)

    # Top-right cell: a frameless, tickless axes (named for legend use)
    axleg = fig.add_subplot(gs[0, -1])
    for spine in axleg.spines.values():
        spine.set_visible(False)
    # Real booleans: recent matplotlib no longer maps the strings 'off'/'on'
    # to False/True, and any non-empty string is truthy.
    axleg.tick_params(labelbottom=False,
                      labelleft=False,
                      bottom=False,
                      left=False)

    if savefile:
        fig.savefig(savefile)

    return xlim, ylim
Exemple #10
0
def plotMultiPane(dim_pairs,
                  star_pars,
                  groups,
                  origins=None,
                  save_file='dummy.pdf',
                  title=None):
    """
    Flexible function that plots many 2D slices through data and fits

    Takes as input a list of dimension pairs, stellar data and fitted
    groups, and will plot each dimension pair in a different pane. Panes
    are laid out on a grid that is never taller than it is wide.

    TODO: Maybe add functionality to control the data plotted in each pane

    Parameters
    ----------
    dim_pairs: a non-empty list of dimension pairs e.g.
        [(0,1), (3,4), (0,3), (1,4), (2,5)]
        ['xz', 'uv', 'zw']
        ['XY', 'UV', 'XU', 'YV', 'ZW']
    star_pars: either
        dictionary of stellar data with keys 'xyzuvw' and 'xyzuvw_cov'
            or
        string filename to saved data
    groups: either
        a single synthesiser.Group object,
        a list or array of synthesiser.Group objects,
            or
        string filename to data saved as '.npy' file
    origins: (optional) a single group object or a sequence of them, passed
        to every pane
    save_file: string, name (and path) of saved plot figure (pdf format);
        a falsy value skips saving
    title: (optional) string used as the figure's super-title

    Returns
    -------
    (nothing)

    Raises
    ------
    ValueError
        If `dim_pairs` is empty.
    """

    # Tidying up inputs
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(groups, str):
        # Group objects are pickled inside the .npy file, so allow_pickle is
        # required; only load trusted files.
        groups = np.load(groups, allow_pickle=True)
        # handle case where groups is a single stored object
        if len(groups.shape) == 0:
            groups = groups.item()
    # ensure groups is iterable
    try:
        len(groups)
    except TypeError:  # groups is a single group instance
        groups = [groups]

    # `is not None` rather than truthiness: truth-testing a numpy array of
    # several groups raises ValueError
    if origins is not None:
        try:
            len(origins)
        except TypeError:  # origins is a single group instance
            origins = [origins]

    # setting up plot dimensions
    npanes = len(dim_pairs)
    if npanes == 0:
        raise ValueError("dim_pairs must contain at least one dimension pair")
    rows = int(np.sqrt(npanes))  # plots are never taller than wide
    cols = (npanes + rows - 1) // rows  # get enough cols
    ax_h = 5
    ax_w = 5
    # squeeze=False always yields a 2D array of axes (even for one pane),
    # so flattening below is always valid
    f, axs = plt.subplots(rows, cols, squeeze=False)
    f.set_size_inches(ax_w * cols, ax_h * rows)
    flat_axs = axs.flatten()

    # drawing each axes
    for i, (dim1, dim2) in enumerate(dim_pairs):
        plotPane(dim1,
                 dim2,
                 flat_axs[i],
                 groups=groups,
                 origins=origins,
                 star_pars=star_pars,
                 star_orbits=False,
                 group_then=True,
                 group_now=True,
                 group_orbit=True,
                 annotate=False)

    if title:
        f.suptitle(title)
    if save_file:
        f.savefig(save_file, format='pdf')
Exemple #11
0
def plotPane(dim1=0,
             dim2=1,
             ax=None,
             groups=(),
             star_pars=None,
             origin_star_pars=None,
             star_orbits=False,
             origins=None,
             group_then=False,
             group_now=False,
             group_orbit=False,
             annotate=False,
             membership=None,
             true_memb=None,
             savefile='',
             with_bg=False,
             markers=None,
             group_bg=False,
             marker_labels=None,
             color_labels=None,
             marker_style=None,
             marker_legend=None,
             color_legend=None,
             star_pars_label=None,
             origin_star_pars_label=None,
             range_1=None,
             range_2=None,
             isotropic=False,
             ordering=None,
             no_bg_covs=False):
    """
    Plots a single pane capturing kinematic info in any desired 2D plane

    Uses global constants COLORS, HATCHES, MARKERS, MARK_SIZE, BG_MARK_SIZE,
    PT_ALPHA, COV_ALPHA and FONTSIZE to inform a consistent style.
    Can use this to plot different panes of one whole figure

    Parameters
    ----------
    dim1: x-axis, can either be integer 0-5 (inclusive) or a letter from
          'xyzuvw' (either case)
    dim2: y-axis, same conditions as dim1
    ax:   the axes object on which to plot (defaults to pyplot's current axes)
    groups: a list of (or just one) group objects, or a filename passed to
            dt.loadGroups, corresponding to the fit of the origin(s)
    star_pars:  dict object with keys 'xyzuvw' ([nstars,6] array of current
                star means) and optionally 'xyzuvw_cov' ([nstars,6,6] array
                of current star covariance matrices)
    origin_star_pars: (optional) dict with key 'xyzuvw'; these positions are
                drawn as small squares coloured like the current stars
    star_orbits: (bool) plot the calculated stellar traceback orbits of
                        central estimate of measurements (every 3rd star)
    origins: (optional) one or more origin objects, drawn in grey
    group_then: (bool) plot the group's origin
    group_now:  (bool) plot the group's current day distribution
    group_orbit: (bool) plot the trajectory of the group's mean
    annotate: (bool) add text describing the figure's contents
    membership: (optional) membership array (or .npy filename); each star is
                coloured by its most probable column
    true_memb: (optional) "true" membership array used to choose markers
    savefile: if non-empty, the current figure is cleared first and the
              result saved to this path
    with_bg: (bool) treat the last column in Z as members of background, and
            color accordingly
    markers: (optional) per-star marker symbols
    group_bg: (bool) stars whose most probable column is ngroups - 1 (rather
              than ngroups) are treated as background
    marker_labels, color_labels: (optional) labels for legends of marker
              styles and of component colours respectively
    marker_style: marker symbols used for marker_labels (defaults to MARKERS)
    marker_legend, color_legend: (optional) dicts mapping a label to a
              marker / colour; both must be given for that legend to be drawn
    star_pars_label, origin_star_pars_label: currently unused
    range_1, range_2: (optional) axis limits for dim1 and dim2 (range_1 is
              ignored when `isotropic` is set)
    isotropic: (bool) force equal scaling on both axes
    ordering: (optional) index array reordering `marker_style` for the
              marker legend
    no_bg_covs: (bool) ignore covariance matrices of stars fitted to background

    Returns
    -------
    (xlim, ylim): the final x and y limits of `ax`
    """
    labels = 'XYZUVW'
    units = 3 * ['pc'] + 3 * ['km/s']

    if savefile:
        plt.clf()

    # Tidying up inputs
    if ax is None:
        ax = plt.gca()
    if not isinstance(dim1, (int, np.integer)):
        dim1 = labels.index(dim1.upper())
    if not isinstance(dim2, (int, np.integer)):
        dim2 = labels.index(dim2.upper())
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(membership, str):
        membership = np.load(membership)
    if isinstance(groups, str):
        groups = dt.loadGroups(groups)
    if marker_style is None:
        marker_style = MARKERS[:]

    # ensure groups is iterable
    try:
        len(groups)
    except TypeError:  # groups is a single group instance
        groups = [groups]
    ngroups = len(groups)
    if ordering is None:
        ordering = list(range(len(marker_style)))

    # plot stellar data (positions with errors and optionally traceback
    # orbits back to some ill-defined age)
    if star_pars:
        nstars = star_pars['xyzuvw'].shape[0]

        # apply default color and markers, to be overwritten if needed
        pt_colors = np.array(nstars * [COLORS[0]])
        if markers is None:
            markers = np.array(nstars * ['.'])

        # Incorporate fitted membership into colors of the pts
        if membership is not None:
            best_mship = np.argmax(membership[:, :ngroups + with_bg], axis=1)
            pt_colors = np.array(COLORS[:ngroups] +
                                 with_bg * ['xkcd:grey'])[best_mship]
            # Incorporate "True" membership into pt markers
            if true_memb is not None:
                markers = np.array(MARKERS)[np.argmax(true_memb, axis=1)]
                if with_bg:
                    true_bg_mask = np.where(true_memb[:, -1] == 1.)
                    markers[true_bg_mask] = '.'
        all_mark_size = np.array(nstars * [MARK_SIZE])

        # Boolean mask of stars whose most probable column is the background:
        # column `ngroups`, or the final fitted component when group_bg is
        # set. Needs a membership array; without one no star counts as bg.
        bg_mask = None
        if with_bg and membership is not None:
            bg_mask = np.argmax(membership, axis=1) == ngroups - group_bg
            all_mark_size[bg_mask] = BG_MARK_SIZE

        mns = star_pars['xyzuvw']
        try:
            # Held as a list so individual entries can be replaced by None:
            # assigning None into a float array stores NaN instead, which
            # would defeat the `is not None` check below.
            covs = list(np.copy(star_pars['xyzuvw_cov']))
        except KeyError:
            covs = len(mns) * [None]
            star_pars['xyzuvw_cov'] = covs
        else:
            # replace background cov matrices with None so as to avoid plotting
            if with_bg and no_bg_covs and bg_mask is not None:
                print("Discarding background cov-mats")
                for star_ix in np.flatnonzero(bg_mask):
                    covs[star_ix] = None

        # traceback orbits last as long as the oldest group (if known),
        # else 30 Myr; loop-invariant, so computed once here
        if star_orbits:
            try:
                tb_limit = max(g.get_age() for g in groups)
            except (AttributeError, TypeError, ValueError):
                tb_limit = 30

        for st_count, (star_mn, star_cov, marker, pt_color, m_size) in \
                enumerate(zip(mns, covs, markers, pt_colors, all_mark_size)):
            ax.scatter(
                star_mn[dim1],
                star_mn[dim2],
                s=m_size,
                color=pt_color,
                marker=marker,
                alpha=PT_ALPHA,
                linewidth=0.0,
            )
            # plot uncertainties
            if star_cov is not None:
                plotCovEllipse(
                    star_cov[np.ix_([dim1, dim2], [dim1, dim2])],
                    star_mn[np.ix_([dim1, dim2])],
                    ax=ax,
                    alpha=COV_ALPHA,
                    linewidth='0.1',
                    color=pt_color,
                )
            # traceback orbits are drawn for every third star only
            if star_orbits and st_count % 3 == 0:
                plotOrbit(star_mn,
                          dim1,
                          dim2,
                          ax,
                          end_age=-tb_limit,
                          color='xkcd:grey')

        if origin_star_pars is not None:
            # `_marker` is unused, but `markers` stays in the zip because zip
            # stops at its shortest input
            for star_mn, _marker, pt_color, m_size in zip(
                    origin_star_pars['xyzuvw'], markers, pt_colors,
                    all_mark_size):
                ax.scatter(
                    star_mn[dim1],
                    star_mn[dim2],
                    s=0.5 * m_size,
                    color=pt_color,
                    marker='s',
                    alpha=PT_ALPHA,
                    linewidth=0.0,
                )

    # plot info for each group (fitted, or true synthetic origin)
    for i, group in enumerate(groups):
        cov_then = group.get_covmatrix()
        mean_then = group.get_mean()
        # plot group initial distribution
        if group_then:
            ax.plot(mean_then[dim1],
                    mean_then[dim2],
                    marker='+',
                    alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_then[np.ix_([dim1, dim2])],
                           with_line=True,
                           ax=ax,
                           alpha=0.3,
                           ls='--',
                           color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_0, \mathbf{\Sigma}_0$',
                            (mean_then[dim1], mean_then[dim2]),
                            color=COLORS[i])

        # The current-day mean is needed for group_now and as the start point
        # of group_orbit, so compute it whenever either is requested (using
        # it only under group_now left it undefined or stale otherwise).
        if group_now or group_orbit:
            mean_now = torb.trace_cartesian_orbit(mean_then,
                                                  group.get_age(),
                                                  single_age=True)

        # plot group current day distribution (should match well with stars)
        if group_now:
            cov_now = tf.transform_covmatrix(cov_then,
                                             torb.trace_cartesian_orbit,
                                             mean_then,
                                             args=[group.get_age()])
            ax.plot(mean_now[dim1],
                    mean_now[dim2],
                    marker='+',
                    alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(
                cov_now[np.ix_([dim1, dim2], [dim1, dim2])],
                mean_now[np.ix_([dim1, dim2])],
                ax=ax,
                alpha=0.4,
                ls='-.',
                ec=COLORS[i],
                fill=False,
                hatch=HATCHES[i],
                color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_c, \mathbf{\Sigma}_c$',
                            (mean_now[dim1], mean_now[dim2]),
                            color=COLORS[i])

        # plot orbit of mean of group
        if group_orbit:
            plotOrbit(mean_now,
                      dim1,
                      dim2,
                      ax,
                      -group.get_age(),
                      group_ix=i,
                      with_arrow=True,
                      annotate=annotate)

    # `is not None` rather than truthiness: truth-testing a numpy array of
    # several origins raises ValueError
    if origins is not None:
        try:
            len(origins)
        except TypeError:  # origins is a single origin instance
            origins = [origins]
        for origin in origins:
            cov_then = origin.generateSphericalCovMatrix()
            mean_then = origin.mean
            # plot origin initial distribution
            ax.plot(mean_then[dim1],
                    mean_then[dim2],
                    marker='+',
                    color='xkcd:grey')
            plotCovEllipse(cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_then[np.ix_([dim1, dim2])],
                           with_line=True,
                           ax=ax,
                           alpha=0.1,
                           ls='--',
                           color='xkcd:grey')

    ax.set_xlabel("{} [{}]".format(labels[dim1], units[dim1]))
    ax.set_ylabel("{} [{}]".format(labels[dim2], units[dim2]))

    # update fontsize
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(FONTSIZE)
    if ax.get_legend() is not None:
        for item in ax.get_legend().get_texts():
            item.set_fontsize(FONTSIZE)

    if range_2:
        ax.set_ylim(range_2)

    if isotropic:
        print("Setting isotropic for dims {} and {}".format(dim1, dim2))
        # Use `ax`, not plt.gca(): they differ when drawing into a
        # non-current axes (e.g. one pane of a multi-pane figure).
        ax.set_aspect('equal', adjustable='datalim')

        # manually calculate what the new xaxis must be...
        figW, figH = ax.get_figure().get_size_inches()
        xmid = (ax.get_xlim()[1] + ax.get_xlim()[0]) * 0.5
        yspan = ax.get_ylim()[1] - ax.get_ylim()[0]
        xspan = figW * yspan / figH

        # check if this increases span
        if xspan > ax.get_xlim()[1] - ax.get_xlim()[0]:
            ax.set_xlim(xmid - 0.5 * xspan, xmid + 0.5 * xspan)
        # if not, need to increase yspan
        else:
            ymid = (ax.get_ylim()[1] + ax.get_ylim()[0]) * 0.5
            xspan = ax.get_xlim()[1] - ax.get_xlim()[0]
            yspan = figH * xspan / figW
            ax.set_ylim(ymid - 0.5 * yspan, ymid + 0.5 * yspan)
    elif range_1:
        ax.set_xlim(range_1)

    # Legend entries are drawn as off-screen dummy artists at (1e10, 1e10);
    # the limits are saved and restored so they don't stretch the view.
    if color_labels is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for i, color_label in enumerate(color_labels):
            ax.plot(1e10, 1e10, color=COLORS[i], label=color_label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if marker_labels is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for i, marker_label in enumerate(marker_labels):
            ax.scatter(
                1e10,
                1e10,
                c='black',
                marker=np.array(marker_style)[ordering][i],
                label=marker_label)
        if with_bg:
            ax.scatter(1e10,
                       1e10,
                       c='xkcd:grey',
                       marker='.',
                       label='Background')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if color_legend is not None and marker_legend is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for label in color_legend.keys():
            ax.scatter(1e10,
                       1e10,
                       c=color_legend[label],
                       marker=marker_legend[label],
                       label=label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if savefile:
        plt.savefig(savefile)

    return ax.get_xlim(), ax.get_ylim()
"""
from __future__ import print_function, division

import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sys
sys.path.insert(0, '..')
import chronostar.retired2.datatool as dt

# Padding applied on each side of the association's bounding box, expressed
# as a multiple of the association's extent (max - min) in each dimension.
MARGIN = 2.0

assoc_name = "bpmg_cand_w_gaia_dr2_astrometry_comb_binars"
xyzuvw_file = "../data/{}_xyzuvw.fits".format(assoc_name)
xyzuvw_dict = dt.loadXYZUVW(xyzuvw_file)

# construct box
# Per-column extrema of the association's kinematics (reduced over axis 0).
# NOTE(review): assumes 'xyzuvw' is an (nstars, 6) array -- confirm.
kin_max = np.max(xyzuvw_dict['xyzuvw'], axis=0)
kin_min = np.min(xyzuvw_dict['xyzuvw'], axis=0)
span = kin_max - kin_min
upper_boundary = kin_max + MARGIN * span
lower_boundary = kin_min - MARGIN * span

# get gaia stars within box
gaia_xyzuvw_file = "../data/gaia_dr2_mean_xyzuvw.npy"
gaia_xyzuvw = np.load(gaia_xyzuvw_file)
# A Gaia star is kept only if it lies strictly inside the box in every
# column (np.all over axis=1). np.where returns a tuple; mask[0] holds the
# row indices of the selected stars.
mask = np.where(
    np.all((gaia_xyzuvw < upper_boundary) & (gaia_xyzuvw > lower_boundary),
           axis=1))
# Report how many Gaia stars fall inside the box.
print(mask[0].shape)
    # Setting up file names
    synth_fit = 'fed_stars'
    rdir = '../../results/archive/fed_fits/20/gaia/'
    origins_file = rdir + 'origins.npy'
    origin_comp_file = rdir + 'origin_ellip_comp.npy'
    chain_file = rdir + 'final_chain.npy'
    lnprob_file = rdir + 'final_lnprob.npy'
    star_pars_file = rdir + 'xyzuvw_now.fits'
    init_xyzuvw_file = rdir + '../xyzuvw_init_offset.npy'
    perf_xyzuvw_now = rdir + '../perf_xyzuvw.npy'

    # loading in data
    best_comp = SphereComponent.get_best_from_chain(chain_file, lnprob_file)
    origin_comp = EllipComponent.load_raw_components(origin_comp_file)[0]
    init_xyzuvw = np.load(init_xyzuvw_file)
    star_pars = dt.loadXYZUVW(star_pars_file)
    perf_mean_now = np.load(perf_xyzuvw_now)

    original_origin = dt.loadGroups(origins_file)[0]

    # assigning useful shorthands
    mns = star_pars['xyzuvw']
    covs = star_pars['xyzuvw_cov']

    fed_xranges, fed_yranges = calcRanges(
        {'xyzuvw': np.vstack((star_pars['xyzuvw'], init_xyzuvw))},
        sep_axes=True,
    )

    labels = 'XYZUVW'
    units = 3 * ['[pc]'] + 3 * ['[km/s]']
Exemple #14
0
logging.info("  with error fraction {}".format(ERROR))
logging.info("  and background density {}".format(BG_DENS))
# Set a current-day location around which synth stars will end up
mean_now = np.array([50., -100., -0., -10., -20., -5.])

logging.info("Mean (now):\n{}".format(mean_now))
logging.info("Extra pars:\n{}".format(extra_pars))
logging.info("Offsets:\n{}".format(offsets))

try:
    #all_xyzuvw_now_perf = np.load(xyzuvw_perf_file)
    np.load(xyzuvw_perf_file)
    #origins = dt.loadGroups(groups_savefile)
    dt.loadGroups(groups_savefile)
    #star_pars = dt.loadXYZUVW(xyzuvw_conv_savefile)
    dt.loadXYZUVW(xyzuvw_conv_savefile)
    logging.info("Synth data exists! .....")
    print("Synth data exists")
    raise UserWarning
except IOError:
    all_xyzuvw_init = np.zeros((0, 6))
    all_xyzuvw_now_perf = np.zeros((0, 6))
    origins = []
    for i in range(ngroups):
        logging.info(" generating from group {}".format(i))
        # MANUALLY SEPARATE CURRENT DAY DISTROS IN DIMENSION X
        mean_now_w_offset = mean_now.copy()
        # mean_now_w_offset[0] += i * 50
        mean_now_w_offset += offsets[i]

        mean_then = torb.trace_cartesian_orbit(mean_now_w_offset,
Exemple #15
0
def plotPaneWithHists(dim1, dim2, fignum=None, groups=None, weights=None,
                      star_pars=None,
                      star_orbits=False,
                      group_then=False, group_now=False, group_orbit=False,
                      annotate=False, bg_hists=None, membership=None,
                      true_memb=None, savefile='', with_bg=False,
                      range_1=None, range_2=None, residual=False,
                      markers=None, group_bg=False, isotropic=False,
                      color_labels=None, marker_labels=None, marker_order=None,
                      ordering=None, no_bg_covs=False):
    """
    Plot a 2D projection of data and fit along with flanking 1D projections.

    Uses global constants COLORS and HATCHES to inform consistent colour
    scheme.
    Can use this to plot different panes of one whole figure

    TODO: Incorporate Z
    TODO: incoporate background histogram

    Parameters
    ----------
    dim1: x-axis, can either be integer 0-5 (inclusive) or a letter from
          'xyzuvw' (either case)
    dim2: y-axis, same conditions as dim1
    fignum: figure number in which to create the plot (a new figure if None)
    groups: a list of (or just one) synthesiser.Group objects, corresponding
            to the fit of the origin(s), or the path of a .npy file holding
            them. Defaults to no groups.
    weights: per-group weights for the 1D projections. If None and groups
             are given, they are derived: the star count for a single group
             without background, else the column sums of `membership`, else
             uniform.
    star_pars:  dict object with keys 'xyzuvw' ([nstars,6] array of current
                star means) and 'xyzuvw_cov' ([nstars,6,6] array of current
                star covariance matrices), or a path loadable by
                dt.loadXYZUVW
    star_orbits: (bool) plot the calculated stellar traceback orbits of
                        central estimate of measurements
    group_then: (bool) plot the group's origin
    group_now:  (bool) plot the group's current day distribution
    group_orbit: (bool) plot the trajectory of the group's mean
    annotate: (bool) add text describing the figure's contents
    bg_hists: background histograms (or a .npy path) for the 1D projections
    membership: membership array Z (or a .npy path)
    with_bg: (bool) treat the final column of Z as background memberships
             and color accordingly
    color_labels, marker_labels: legend labels forwarded to plotPane
                                 (default: empty)
    marker_order: accepted for backwards compatibility; currently unused

    Returns
    -------
    xlim, ylim: the x and y limits of the central pane, as returned by
                plotPane
    """
    labels = 'XYZUVW'
    axes_units = 3*['pc'] + 3*['km/s']

    # `None` defaults replace the former shared mutable `[]` defaults; restore
    # empty lists so the downstream calls receive exactly what they used to.
    if groups is None:
        groups = []
    if color_labels is None:
        color_labels = []
    if marker_labels is None:
        marker_labels = []

    # Load any inputs that were given as file paths.
    if isinstance(membership, str):
        membership = np.load(membership)
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(groups, str):
        # NOTE(review): this unpickles objects stored in the file
        # (allow_pickle=True is required for arrays of group objects on
        # numpy >= 1.16.3). Only pass paths to trusted, locally made files.
        groups = np.load(groups, allow_pickle=True)
    if isinstance(bg_hists, str):
        bg_hists = np.load(bg_hists)

    # Normalise `groups` into a sized sequence BEFORE it is used to derive
    # weights. A single saved object loads as a 0-d array, and callers may
    # also pass one bare group object.
    if isinstance(groups, np.ndarray) and groups.ndim == 0:
        groups = groups.item()
    try:
        len(groups)
    except TypeError:
        groups = [groups]

    # TODO: clarify what exactly you're trying to do here
    if weights is None and len(groups) > 0:
        if len(groups) == 1 and not with_bg:
            weights = np.array([len(star_pars['xyzuvw'])])
        elif membership is not None:
            weights = membership.sum(axis=0)
        else:
            weights = np.ones(len(groups)) / len(groups)

    if not isinstance(dim1, int):
        dim1 = labels.index(dim1.upper())
    if not isinstance(dim2, int):
        dim2 = labels.index(dim2.upper())

    # Tick styling, applied to each pane after it has been drawn.
    tick_params = {'direction': 'in', 'top': True, 'right': True}

    # A 5x5 inch figure on a 4x4 grid: 1D histograms along the top row and
    # the right column, the 2D pane in the remaining 3x3 block and a
    # (blank) legend cell in the top-right corner. All panes are drawn on
    # THIS figure; no second figure is created.
    plt.figure(fignum, figsize=(5, 5))
    plt.clf()
    gs = gridspec.GridSpec(4, 4)

    # Plot central pane
    axcen = plt.subplot(gs[1:, :-1])
    xlim, ylim = plotPane(
        dim1, dim2, ax=axcen, groups=groups, star_pars=star_pars,
        star_orbits=star_orbits, group_then=group_then,
        group_now=group_now, group_orbit=group_orbit, annotate=annotate,
        membership=membership, true_memb=true_memb, with_bg=with_bg,
        markers=markers, group_bg=group_bg, isotropic=isotropic,
        range_1=range_1, range_2=range_2, marker_labels=marker_labels,
        color_labels=color_labels, ordering=ordering, no_bg_covs=no_bg_covs)
    plt.tick_params(**tick_params)

    # Top flanking 1D projection, sharing the central pane's x-range.
    axtop = plt.subplot(gs[0, :-1])
    axtop.set_xlim(xlim)
    axtop.set_xticklabels([])
    plot1DProjection(dim1, star_pars, groups, weights, ax=axtop,
                     bg_hists=bg_hists, with_bg=with_bg, membership=membership,
                     residual=residual, x_range=xlim)
    axtop.set_ylabel('Stars per {}'.format(axes_units[dim1]))
    plt.tick_params(**tick_params)

    # Right flanking 1D projection (horizontal), sharing the y-range.
    axright = plt.subplot(gs[1:, -1])
    axright.set_ylim(ylim)
    axright.set_yticklabels([])
    plot1DProjection(dim2, star_pars, groups, weights, ax=axright,
                     bg_hists=bg_hists, horizontal=True, with_bg=with_bg,
                     membership=membership, residual=residual,
                     x_range=ylim)
    axright.set_xlabel('Stars per {}'.format(axes_units[dim2]))
    plt.tick_params(**tick_params)

    # Top-right cell: hide spines, ticks and tick labels entirely.
    # tick_params expects real booleans; strings such as 'off' are truthy.
    axleg = plt.subplot(gs[0, -1])
    for spine in axleg.spines.values():
        spine.set_visible(False)
    axleg.tick_params(labelbottom=False, labelleft=False, bottom=False,
                      left=False)

    if savefile:
        plt.savefig(savefile)

    return xlim, ylim
Exemple #16
0
if os.path.isfile(rdir + 'bg_hists.npy'):
    bg_hists = np.load(rdir + 'bg_hists.npy')
else:
    bg_hists = None


is_inc_fit = os.path.isdir(rdir + '1/')
is_synth_fit = os.path.isdir(rdir + 'synth_data/')

# First, if stars are synthetic, plot true groups
true_memb = None
if is_synth_fit:
    # origins.npy holds pickled group objects, so allow_pickle is required
    # (numpy >= 1.16.3 refuses object arrays otherwise).
    # NOTE(review): this unpickles the file; it is this pipeline's own
    # output and must never point at untrusted data.
    origins = np.load(rdir + 'synth_data/origins.npy', allow_pickle=True)
    # A single saved object loads as a 0-d array, which has no len(). Unwrap
    # it into a 1-element array BEFORE getZfromOrigins / len() use it.
    if origins.ndim == 0:
        origins = np.atleast_1d(origins.item())
    true_memb = getZfromOrigins(origins, star_pars_file)
    # An extra membership column beyond the origins marks background stars
    with_bg = len(origins) < true_memb.shape[1]
    assert true_memb.shape[0] == dt.loadXYZUVW(star_pars_file)['xyzuvw'].shape[0]
    weights = np.array([origin.nstars for origin in origins])
    for dim1, dim2 in ('xy', 'uv', 'xu', 'yv', 'zw', 'xw'):
        plt.clf()
        fp.plotPaneWithHists(dim1, dim2, star_pars=star_pars_file,
                             groups=origins, weights=weights,
                             group_now=True, with_bg=with_bg,
                             no_bg_covs=with_bg,
                             )
        plt.savefig(rdir + 'pre_plot_{}{}.pdf'.format(dim1, dim2))

# Now choose if handling incremental fit or plain fit
if not is_inc_fit:
    plotEveryIter(rdir, star_pars_file)
            labels=axis_labels,
            # reverse=True,
            label_kwargs={'fontsize':'xx-large'},
            max_n_ticks=4,
        )
        print("Applying tick parameters")
        for ax in fig.axes:
            ax.tick_params(direction='in', labelsize='x-large', top=True,
                           right=True)
        print("... saving")
        plt.savefig(plot_name)

if PLOT_BPMG_REAL:
    for iteration in ['5B']: #, '6C']:
        star_pars_file = '../../data/beta_Pictoris_with_gaia_small_xyzuvw.fits'
        star_pars = dt.loadXYZUVW(star_pars_file)
        fit_name = 'bpmg_and_nearby'
        rdir = '../../results/em_fit/beta_Pictoris_wgs_inv2_{}_res/'.format(iteration)

        memb_file = rdir + 'final_membership.npy'
        groups_file = rdir + 'final_groups.npy'

        z = np.load(memb_file)
        groups = dt.loadGroups(groups_file)

        # Assign markers based on BANYAN membership
        gt_sp = dt.loadDictFromTable('../../data/banyan_with_gaia_near_bpmg_xyzuvw.fits')
        banyan_membership = len(star_pars['xyzuvw']) * ['N/A']
        for i in range(len(star_pars['xyzuvw'])):
            master_table_ix = np.where(gt_sp['table']['source_id']==star_pars['gaia_ids'][i])
            banyan_membership[i] = gt_sp['table']['Moving group'][master_table_ix[0][0]]
    using_mpi = False
    pool=None

if using_mpi:
    if not pool.is_master():
        print("One thread is going to sleep")
        # Wait for instructions from the master process.
        pool.wait()
        sys.exit(0)
print("Only one thread is master")

print("Master should be working in the directory:\n{}".format(rdir))

star_pars = dt.loadXYZUVW(xyzuvw_file, assoc_name=ass_name)

# GET BACKGROUND LOG OVERLAP DENSITIES:
# Reuse the cached overlaps in `bg_ln_ols_file` when the file exists and has
# one entry per star; otherwise recompute them from the Gaia sample.
logging.info("Acquiring background overlaps")
bg_ln_ols = None
try:
    bg_ln_ols = np.load(bg_ln_ols_file)
except IOError:
    logging.info("No cached background overlaps at {}".format(bg_ln_ols_file))
else:
    # Explicit check rather than `assert`, which `python -O` strips out and
    # would then silently accept a mismatched cache.
    if len(bg_ln_ols) != len(star_pars['xyzuvw']):
        logging.info("Cached background overlaps do not match the number "
                     "of stars; recalculating")
        bg_ln_ols = None
    else:
        logging.info("Loaded bg_ln_ols from file")

if bg_ln_ols is None:
    # Only announce the slow path when it is actually taken.
    logging.info("Calculating background overlaps")
    logging.info(" -- this step employs scipy's kernel density estimator")
    logging.info(" -- so could take a few minutes...")
    bg_ln_ols = dt.getKernelDensities(gaia_xyzuvw_file, star_pars['xyzuvw'])
logging.info("  with error fraction {}".format(ERROR))
logging.info("  and background density {}".format(BG_DENS))
# Set a current-day location around which synth stars will end up
mean_now = np.array([50., -100., -0., -10., -20., -5.])

logging.info("Mean (now):\n{}".format(mean_now))
logging.info("Extra pars:\n{}".format(extra_pars))
logging.info("Offsets:\n{}".format(offsets))

try:
    #all_xyzuvw_now_perf = np.load(xyzuvw_perf_file)
    np.load(xyzuvw_perf_file)
    #origins = dt.loadGroups(groups_savefile)
    dt.loadGroups(groups_savefile)
    #star_pars = dt.loadXYZUVW(xyzuvw_conv_savefile)
    dt.loadXYZUVW(xyzuvw_conv_savefile)
    logging.info("Synth data exists! .....")
    print("Synth data exists")
    raise UserWarning
except IOError:
    all_xyzuvw_init = np.zeros((0,6))
    all_xyzuvw_now_perf = np.zeros((0,6))
    origins = []
    for i in range(ngroups):
        logging.info(" generating from group {}".format(i))
        # MANUALLY SEPARATE CURRENT DAY DISTROS IN DIMENSION X
        mean_now_w_offset = mean_now.copy()
        # mean_now_w_offset[0] += i * 50
        mean_now_w_offset += offsets[i]
    
        mean_then = torb.trace_cartesian_orbit(mean_now_w_offset, -extra_pars[i, -2],
Exemple #20
0
def plotPane(dim1=0, dim2=1, ax=None, groups=(), star_pars=None,
             origin_star_pars=None,
             star_orbits=False, origins=None,
             group_then=False, group_now=False, group_orbit=False,
             annotate=False, membership=None, true_memb=None,
             savefile='', with_bg=False, markers=None, group_bg=False,
             marker_labels=None, color_labels=None,
             marker_style=None,
             marker_legend=None, color_legend=None,
             star_pars_label=None, origin_star_pars_label=None,
             range_1=None, range_2=None, isotropic=False,
             ordering=None, no_bg_covs=False):
    """
    Plots a single pane capturing kinematic info in any desired 2D plane

    Uses global constants COLORS and HATCHES to inform consistent colour
    scheme.
    Can use this to plot different panes of one whole figure

    Parameters
    ----------
    dim1: x-axis, can either be integer 0-5 (inclusive) or a letter from
          'xyzuvw' (either case)
    dim2: y-axis, same conditions as dim1
    ax:   the axes object on which to plot (defaults to pyplot's current axes)
    groups: a list of (or just one) synthesiser.Group objects, corresponding
            to the fit of the origin(s); a string is loaded with
            dt.loadGroups
    star_pars:  dict object with keys 'xyzuvw' ([nstars,6] array of current
                star means) and optionally 'xyzuvw_cov' ([nstars,6,6] array of
                current star covariance matrices), or a path loadable by
                dt.loadXYZUVW. If 'xyzuvw_cov' is missing, it is added to the
                dict as a list of None.
    origin_star_pars: dict with key 'xyzuvw'; drawn as small squares coloured
                      like the corresponding current-day stars
    star_orbits: (bool) plot the calculated stellar traceback orbits of the
                        central estimate of every third star
    origins: iterable of group-like objects drawn as grey origin ellipses
    group_then: (bool) plot the group's origin
    group_now:  (bool) plot the group's current day distribution
    group_orbit: (bool) plot the trajectory of the group's mean
    annotate: (bool) add text describing the figure's contents
    membership: [nstars, ncomps] array (or .npy path); each star is coloured
                by its highest-membership component
    true_memb: [nstars, ncomps] array; only used when `membership` is given,
               picks each star's marker from MARKERS
    savefile: if non-empty, the current figure is cleared first and the
              result saved here
    with_bg: (bool) treat the last column in Z as members of background, and
            color accordingly (needs `membership` to take effect)
    markers: per-star markers; defaults to '.'
    group_bg: (bool) the background is fitted by the final component in
              `groups` rather than by an extra membership column
    marker_labels, color_labels: labels for proxy legend entries
    marker_style: markers used for `marker_labels` (default: MARKERS)
    marker_legend, color_legend: dicts mapping legend label -> marker/colour;
                                 both must be given to take effect
    star_pars_label, origin_star_pars_label: accepted for backwards
        compatibility; not currently drawn in any legend
    range_1, range_2: x and y limits (range_1 is ignored when isotropic)
    isotropic: (bool) force an equal aspect ratio
    ordering: index sequence into `marker_style` used for `marker_labels`
    no_bg_covs: (bool) ignore covariance matrices of stars fitted to background

    Returns
    -------
    xlim, ylim: the final (min, max) limits of the x and y axes
    """
    labels = 'XYZUVW'
    units = 3 * ['pc'] + 3 * ['km/s']

    if savefile:
        plt.clf()

    # Tidying up inputs
    if ax is None:
        ax = plt.gca()
    if not isinstance(dim1, int):
        dim1 = labels.index(dim1.upper())
    if not isinstance(dim2, int):
        dim2 = labels.index(dim2.upper())
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(membership, str):
        membership = np.load(membership)
    if isinstance(groups, str):
        groups = dt.loadGroups(groups)
    if marker_style is None:
        marker_style = MARKERS[:]

    # ensure groups is iterable (a single group object is wrapped in a list)
    try:
        len(groups)
    except TypeError:
        groups = [groups]
    ngroups = len(groups)
    if ordering is None:
        ordering = range(len(marker_style))

    # plot stellar data (positions with errors and optionally traceback
    # orbits back to some ill-defined age)
    if star_pars:
        mns = star_pars['xyzuvw']
        nstars = mns.shape[0]

        # apply default color and markers, to be overwritten if needed
        pt_colors = np.array(nstars * [COLORS[0]])
        if markers is None:
            markers = np.array(nstars * ['.'])

        # Incorporate fitted membership into colors of the pts
        if membership is not None:
            best_mship = np.argmax(membership[:, :ngroups + with_bg], axis=1)
            pt_colors = np.array(
                COLORS[:ngroups] + with_bg * ['xkcd:grey'])[best_mship]
            # Incorporate "True" membership into pt markers
            if true_memb is not None:
                markers = np.array(MARKERS)[np.argmax(true_memb, axis=1)]
                if with_bg:
                    true_bg_mask = np.where(true_memb[:, -1] == 1.)
                    markers[true_bg_mask] = '.'
        all_mark_size = np.array(nstars * [MARK_SIZE])

        # Stars whose strongest membership is the background. group_bg
        # handles the case where the background is fitted by the final
        # component. Without a membership array there is nothing to flag.
        if with_bg and membership is not None:
            bg_mask = np.argmax(membership, axis=1) == ngroups - group_bg
        else:
            bg_mask = np.zeros(nstars, dtype=bool)
        all_mark_size[bg_mask] = BG_MARK_SIZE

        try:
            covs = star_pars['xyzuvw_cov']
        except KeyError:
            covs = len(mns) * [None]
            # NOTE: kept from the original -- the caller's dict gains the key,
            # which later consumers of the same dict may rely on.
            star_pars['xyzuvw_cov'] = covs
        else:
            if with_bg and no_bg_covs:
                print("Discarding background cov-mats")
        # Explicit per-star skip flags. (Assigning None into a float cov
        # array stores NaN, not None, so it never skipped anything.)
        skip_cov = bg_mask if no_bg_covs else np.zeros(nstars, dtype=bool)

        # Traceback duration: as long as the oldest group (if known), else
        # 30 Myr. Loop invariant, so computed once.
        tb_limit = 30
        if star_orbits:
            try:
                tb_limit = max([g.age for g in groups])
            except (AttributeError, TypeError, ValueError):
                tb_limit = 30

        star_rows = zip(mns, covs, markers, pt_colors, all_mark_size,
                        skip_cov)
        for st_ix, (star_mn, star_cov, marker, pt_color, m_size,
                    no_cov) in enumerate(star_rows):
            ax.scatter(star_mn[dim1], star_mn[dim2], s=m_size,
                       color=pt_color, marker=marker, alpha=PT_ALPHA,
                       linewidth=0.0,
                       )
            # plot uncertainties
            if star_cov is not None and not no_cov:
                plotCovEllipse(star_cov[np.ix_([dim1, dim2], [dim1, dim2])],
                               star_mn[np.ix_([dim1, dim2])],
                               ax=ax, alpha=COV_ALPHA, linewidth='0.1',
                               color=pt_color,
                               )
            # plot traceback orbits for every third star
            if star_orbits and st_ix % 3 == 0:
                plotOrbit(star_mn, dim1, dim2, ax, end_age=-tb_limit,
                          color='xkcd:grey')

        if origin_star_pars is not None:
            # `markers` is zipped only so the loop length matches the
            # original (truncated to the shortest input); squares are used.
            for star_mn, _, pt_color, m_size in \
                    zip(origin_star_pars['xyzuvw'], markers, pt_colors,
                        all_mark_size):
                ax.scatter(star_mn[dim1], star_mn[dim2], s=0.5*m_size,
                           color=pt_color, marker='s', alpha=PT_ALPHA,
                           linewidth=0.0,
                           )

    # plot info for each group (fitted, or true synthetic origin)
    for i, group in enumerate(groups):
        cov_then = group.generateSphericalCovMatrix()
        mean_then = group.mean
        # plot group initial distribution
        if group_then:
            ax.plot(mean_then[dim1], mean_then[dim2], marker='+', alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_then[np.ix_([dim1, dim2])],
                           with_line=True,
                           ax=ax, alpha=0.3, ls='--',
                           color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_0, \mathbf{\Sigma}_0$',
                            (mean_then[dim1],
                             mean_then[dim2]),
                            color=COLORS[i])

        # The current-day mean is needed by both group_now and group_orbit
        # (previously group_orbit alone raised NameError on mean_now).
        if group_now or group_orbit:
            mean_now = torb.trace_cartesian_orbit(mean_then, group.age,
                                                  single_age=True)

        # plot group current day distribution (should match well with stars)
        if group_now:
            cov_now = tf.transform_covmatrix(cov_then,
                                             torb.trace_cartesian_orbit,
                                             mean_then, args=[group.age])
            ax.plot(mean_now[dim1], mean_now[dim2], marker='+', alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(cov_now[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_now[np.ix_([dim1, dim2])],
                           ax=ax, alpha=0.4, ls='-.',
                           ec=COLORS[i], fill=False, hatch=HATCHES[i],
                           color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_c, \mathbf{\Sigma}_c$',
                            (mean_now[dim1], mean_now[dim2]),
                            color=COLORS[i])

        # plot orbit of mean of group
        if group_orbit:
            plotOrbit(mean_now, dim1, dim2, ax, -group.age, group_ix=i,
                      with_arrow=True, annotate=annotate)

    # `is not None` rather than truthiness: numpy arrays of origins have an
    # ambiguous truth value. An empty sequence simply draws nothing.
    if origins is not None:
        for origin in origins:
            cov_then = origin.generateSphericalCovMatrix()
            mean_then = origin.mean
            # plot origin initial distribution
            ax.plot(mean_then[dim1], mean_then[dim2], marker='+',
                    color='xkcd:grey')
            plotCovEllipse(
                cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                mean_then[np.ix_([dim1, dim2])],
                with_line=True,
                ax=ax, alpha=0.1, ls='--',
                color='xkcd:grey')

    ax.set_xlabel("{} [{}]".format(labels[dim1], units[dim1]))
    ax.set_ylabel("{} [{}]".format(labels[dim2], units[dim2]))

    # update fontsize of title, axis labels and tick labels
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(FONTSIZE)

    if range_2:
        ax.set_ylim(range_2)

    if isotropic:
        print("Setting isotropic for dims {} and {}".format(dim1, dim2))
        # Act on `ax` itself; it need not be pyplot's current axes.
        ax.set_aspect('equal', adjustable='datalim')

        # manually calculate what the new xaxis must be...
        figW, figH = ax.get_figure().get_size_inches()
        xmid = (ax.get_xlim()[1] + ax.get_xlim()[0]) * 0.5
        yspan = ax.get_ylim()[1] - ax.get_ylim()[0]
        xspan = figW * yspan / figH

        # check if this increases span
        if xspan > ax.get_xlim()[1] - ax.get_xlim()[0]:
            ax.set_xlim(xmid - 0.5 * xspan, xmid + 0.5 * xspan)
        # if not, need to increase yspan
        else:
            ymid = (ax.get_ylim()[1] + ax.get_ylim()[0]) * 0.5
            xspan = ax.get_xlim()[1] - ax.get_xlim()[0]
            yspan = figH * xspan / figW
            ax.set_ylim(ymid - 0.5*yspan, ymid + 0.5*yspan)
    elif range_1:
        ax.set_xlim(range_1)

    # Legend entries are drawn as proxy artists far outside the view; the
    # axis limits are saved and restored so the proxies don't widen them.
    if color_labels is not None:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        for i, color_label in enumerate(color_labels):
            ax.plot(1e10, 1e10, color=COLORS[i], label=color_label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if marker_labels is not None:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        ordered_markers = np.array(marker_style)[ordering]
        for i, marker_label in enumerate(marker_labels):
            ax.scatter(1e10, 1e10, c='black',
                       marker=ordered_markers[i],
                       label=marker_label)
        if with_bg:
            ax.scatter(1e10, 1e10, c='xkcd:grey',
                       marker='.', label='Background')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if color_legend is not None and marker_legend is not None:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        for label in color_legend.keys():
            ax.scatter(1e10, 1e10, c=color_legend[label],
                       marker=marker_legend[label], label=label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    # Applied last so it reaches the legend actually created above (each
    # ax.legend() call replaces the previous legend with default fonts).
    if ax.get_legend() is not None:
        for item in ax.get_legend().get_texts():
            item.set_fontsize(FONTSIZE)

    if savefile:
        plt.savefig(savefile)

    return ax.get_xlim(), ax.get_ylim()