Beispiel #1
0
def getZfromOrigins(origins, star_pars):
    """
    Build a hard (0/1) "true" membership array from a list of origins.

    Stars are taken to be ordered by origin: the first ``origins[0].nstars``
    rows of ``star_pars['xyzuvw']`` belong to the first origin, the next
    ``origins[1].nstars`` rows to the second, and so on. Any stars left over
    once every origin is accounted for are assigned to an extra, final
    "background" column.

    Parameters
    ----------
    origins: str, or sequence of objects with an ``nstars`` attribute
        If a str, it is treated as a filename and loaded with
        ``dt.loadGroups``.
    star_pars: str, or dict with key 'xyzuvw' (an array whose first axis
        indexes stars)
        If a str, it is treated as a filename and loaded with
        ``dt.loadXYZUVW``.

    Returns
    -------
    z: [nstars, ngroups] array, or [nstars, ngroups + 1] when there are
        leftover (background) stars. Exactly one entry per row is 1.

    Raises
    ------
    ValueError
        If the origins account for more stars than ``star_pars`` holds.
    """
    if isinstance(origins, str):
        origins = dt.loadGroups(origins)
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)

    nstars = star_pars['xyzuvw'].shape[0]
    # Star counts may be stored as floats; slicing needs ints.
    counts = [int(o.nstars) for o in origins]
    nassoc_stars = sum(counts)
    if nassoc_stars > nstars:
        raise ValueError(
            "Origins account for {} stars but only {} were provided".format(
                nassoc_stars, nstars))

    using_bg = nstars > nassoc_stars
    z = np.zeros((nstars, len(counts) + int(using_bg)))

    # set association members' memberships to 1
    start = 0
    for i, count in enumerate(counts):
        z[start:start + count, i] = 1.
        start += count
    # set remaining stars as members of background
    if using_bg:
        z[start:, -1] = 1.
    return z
Beispiel #2
0
# Fit with an increasing number of components while ncomps < MAX_COMP.
# Only the one-component case is visible in this fragment.
# NOTE(review): no increment of ncomps is visible here -- confirm the full
# loop body advances it, otherwise this loop never terminates.
while ncomps < MAX_COMP:
    # NOTE(review): unreachable -- the loop condition already guarantees
    # ncomps < MAX_COMP at this point. This banner probably belongs after
    # the loop.
    if ncomps >= MAX_COMP:
        logging.info("++++++++++++++++++++++++++++++++++++++++++++++++++")
        logging.info("+++++++++++   REACHED MAX COMP LIMIT   +++++++++++")
        logging.info("++++++++++++++++++++++++++++++++++++++++++++++++++")
    # handle special case of one component
    if ncomps == 1:
        logging.info("******************************************")
        logging.info("*********  FITTING 1 COMPONENT  **********")
        logging.info("******************************************")
        run_dir = rdir + '{}/'.format(ncomps)
        mkpath(run_dir)

        # Reuse a previous run's final outputs if they can be loaded;
        # otherwise (IOError) run the fit from scratch.
        try:
            new_groups = dt.loadGroups(run_dir + 'final/final_groups.npy')
            new_meds = np.load(run_dir + 'final/final_med_errs.npy')
            new_z = np.load(run_dir + 'final/final_membership.npy')
            logging.info("Loaded from previous run")
        except IOError:
            # Fresh fit with the given background log-overlaps.
            new_groups, new_meds, new_z = \
                em.fit_many_comps(star_pars, ncomps, rdir=run_dir, pool=pool,
                                  bg_ln_ols=bg_ln_ols) # kill component here
            new_groups = np.array(new_groups)

        # Score the fit: overall log-likelihood, the same with the prior
        # included (inc_posterior=True), and new_BIC from em.calc_bic.
        new_lnlike = em.get_overall_lnlikelihood(star_pars, new_groups,
                                                 bg_ln_ols=bg_ln_ols)
        new_lnpost = em.get_overall_lnlikelihood(star_pars, new_groups,
                                                 bg_ln_ols=bg_ln_ols,
                                                 inc_posterior=True)
        new_BIC = em.calc_bic(star_pars, ncomps, new_lnlike)
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, '..')
import chronostar.traceorbit as torb
import chronostar.retired2.datatool as dt

rdir = "../results/synth_fit/50_2_1_50/"
# rdir = "../results/synth_fit/30_5_2_100/"

xyzuvw_init_file = rdir + "xyzuvw_init.npy"

# Initial phase-space positions of the synthetic stars.
xyzuvw_init = np.load(xyzuvw_init_file)
# Load the origin through datatool rather than a raw np.load(...).item():
# the file holds pickled object(s), which np.load refuses by default on
# NumPy >= 1.16.3 (allow_pickle=False).
# NOTE(review): elsewhere dt.loadGroups(...) results are indexed with [0];
# confirm a single object is returned here, since .age is read directly.
origin = dt.loadGroups(rdir + 'origins.npy')
max_age = origin.age
# One trace time per unit of age (plus the start). The first time is a
# small positive value rather than exactly 0.
ntimes = int(max_age) + 1
#ntimes = 3
times = np.linspace(1e-5, max_age, ntimes)

# Trace every star's orbit through `times`.
traceforward = torb.trace_many_cartesian_orbit(xyzuvw_init, times, False)
nstars = xyzuvw_init.shape[0]

def plot_subplot(traceforward, t_ix, dim1, dim2, ax):
    """
    Presumably intended to plot one (dim1, dim2) pane of the traced orbits
    at time index ``t_ix`` onto ``ax``.

    NOTE(review): as written, the body only flattens ``traceforward`` into
    rows of 6 values and builds axis labels. Nothing is plotted, and
    ``t_ix``, ``dim1``, ``dim2`` and ``ax`` are unused. It returns None.
    The function looks truncated -- confirm against the full source.
    """
    flat_tf = traceforward.reshape(-1,6)
#    mins = np.min(flat_tf, axis=0)
#    maxs = np.max(flat_tf, axis=0)
    labels = ['X [pc]', 'Y [pc]', 'Z [pc]',
              'U [km/s]', 'V [km/s]', 'W [km/s]']
# Filenames (relative to each fit's result directory) of the fitted
# membership array and the median/error parameter array.
fitted_z_stem = 'final_membership.npy'
med_errs_stem = 'final_med_errs.npy'
# Whether the memberships carry a trailing background column (selects how
# true_nstars is reordered below).
use_bg = False

# For each fit, load the true origins/memberships and the fitted results,
# then reorder the true quantities with the manually supplied `order`.
# NOTE(review): fits, orders, suffixs and origin_stem are not defined in
# this visible fragment.
for fit, order, suffix in zip(fits, orders, suffixs):
    print(fit)
    rdir = '../results/em_fit/{}'.format(fit) + suffix + '/'

    star_pars_file = '../data/{}_xyzuvw.fits'.format(fit)
    save_file_name = '../results/tables/{}_table.tex'.format(fit)

    origin_file = rdir + origin_stem
    fitted_z_file = rdir + fitted_z_stem
    med_errs_file = rdir + med_errs_stem

    origins = dt.loadGroups(origin_file)
    # True memberships and the resulting number of stars per column.
    true_z = dt.getZfromOrigins(origins, star_pars_file)
    true_nstars = np.sum(true_z, axis=0)
    # Fitted memberships; column sums give (expected) stars per component.
    fitted_z = np.load(fitted_z_file)
    fitted_nstars = np.sum(fitted_z, axis=0)
    med_errs = np.load(med_errs_file)
    # Exponentiate columns 6 and 7 -- presumably stored as logs; confirm.
    med_errs[:,6:8] = np.exp(med_errs[:,6:8])

    # manually work out ordering of groups
    # NOTE(review): fancy indexing here assumes origins is a numpy array and
    # order is a list of indices.
    origins = origins[(order,)]
    if use_bg:
        # keep the background column (-1) last
        true_nstars = true_nstars[(order + [-1],)]
    else:
        true_nstars = true_nstars[(order,)]
    # manually inform on presence of background or not
def loadFinalResults(fdir):
    """
    Load the final groups, memberships and median/error array of a fit.

    Parameters
    ----------
    fdir: str
        Directory prefix, including the trailing separator, that holds
        'final_groups.npy', 'final_membership.npy' and 'final_med_errs.npy'.

    Returns
    -------
    (groups, memberships, med_errs) in that order.
    """
    groups = dt.loadGroups(fdir + 'final_groups.npy')
    memberships = np.load(fdir + 'final_membership.npy')
    med_errs = np.load(fdir + 'final_med_errs.npy')
    results = (groups, memberships, med_errs)
    return results
Beispiel #6
0
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, '..')
import chronostar.traceorbit as torb
import chronostar.retired2.datatool as dt

rdir = "../results/synth_fit/50_2_1_50/"
# rdir = "../results/synth_fit/30_5_2_100/"

xyzuvw_init_file = rdir + "xyzuvw_init.npy"

# Initial phase-space positions of the synthetic stars.
xyzuvw_init = np.load(xyzuvw_init_file)
# Load the origin through datatool rather than a raw np.load(...).item():
# the file holds pickled object(s), which np.load refuses by default on
# NumPy >= 1.16.3 (allow_pickle=False).
# NOTE(review): elsewhere dt.loadGroups(...) results are indexed with [0];
# confirm a single object is returned here, since .age is read directly.
origin = dt.loadGroups(rdir + 'origins.npy')
max_age = origin.age
# One trace time per unit of age (plus the start). The first time is a
# small positive value rather than exactly 0.
ntimes = int(max_age) + 1
#ntimes = 3
times = np.linspace(1e-5, max_age, ntimes)

# Trace every star's orbit through `times`.
traceforward = torb.trace_many_cartesian_orbit(xyzuvw_init, times, False)
nstars = xyzuvw_init.shape[0]


def plot_subplot(traceforward, t_ix, dim1, dim2, ax):
    """
    Presumably intended to plot one (dim1, dim2) pane of the traced orbits
    at time index ``t_ix`` onto ``ax``.

    NOTE(review): as written, the body only flattens ``traceforward`` into
    rows of 6 values and builds axis labels. Nothing is plotted, and
    ``t_ix``, ``dim1``, ``dim2`` and ``ax`` are unused. It returns None.
    The function looks truncated -- confirm against the full source.
    """
    flat_tf = traceforward.reshape(-1, 6)
    #    mins = np.min(flat_tf, axis=0)
    #    maxs = np.max(flat_tf, axis=0)
    labels = ['X [pc]', 'Y [pc]', 'Z [pc]', 'U [km/s]', 'V [km/s]', 'W [km/s]']
Beispiel #7
0
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import sys
sys.path.insert(0, '..')
import chronostar.retired2.datatool as dt
import chronostar.expectmax as em
from chronostar._overlap import get_lnoverlaps

bpmg_group_file = 'final_groups.npy'
gaia_sep_hist_file = 'bg_hists.npy'
gaia_6d_hist_file = 'gaia_6d_hist.npy'
gaia_xyzuvw = np.load('../data/gaia_dr2_mean_xyzuvw.npy')

# First component of the 1-component BPMG fit; its mean is the reference
# point at which densities are evaluated.
bpmg_group = dt.loadGroups(bpmg_group_file)[0]
# NOTE(review): if .mean is a numpy array, [:] is a view, not a copy, so the
# commented-out sanity check below would also modify bpmg_group.mean.
ref_mean = bpmg_group.mean[:]
# ref_mean[-1] += 4. #SANITY CHECK, perform same analysis for W offset by 4
print("1 comp. fit to BPMG has mean {}".format(ref_mean))
print("We will calculate all densities at this point...")
gaia_sep_hist = np.load(gaia_sep_hist_file)
# Load a cached 6D histogram of the Gaia XYZUVW data, or build and cache it.
# NOTE(review): np.histogramdd returns a (hist, edges) tuple. Saving it with
# np.save creates an object array, which recent NumPy refuses to build from
# ragged input, and which np.load only reads back with allow_pickle=True.
# Confirm this cache works on the NumPy version in use.
try:
    gaia_6d_hist = np.load(gaia_6d_hist_file)
    # [0] is the histogram array; its first axis length is the bin count
    print("Loaded histogram with bins {}".format(gaia_6d_hist[0].shape[0]))
except IOError:
    bins = 20
    print("Generating histogram with bins {}".format(bins))
    # NOTE(review): redundant -- gaia_xyzuvw was already loaded above.
    gaia_xyzuvw = np.load('../data/gaia_dr2_mean_xyzuvw.npy')
    gaia_6d_hist = np.histogramdd(gaia_xyzuvw, bins=bins)
    np.save(gaia_6d_hist_file, gaia_6d_hist)
# Constant per-star log background overlap, or None for "no background".
# NOTE(review): BG_DENS and nstars are not defined in this visible fragment.
if BG_DENS != 0.:
    bg_ln_ols = np.log(np.zeros(nstars) + BG_DENS)
else:
    bg_ln_ols = None

# Fit with an increasing number of components while ncomps < MAX_COMP.
# Only the one-component case is visible in this fragment.
# NOTE(review): no increment of ncomps is visible here -- confirm the full
# loop body advances it, otherwise this loop never terminates.
while ncomps < MAX_COMP:
    # handle special case of one component
    if ncomps == 1:
        logging.info("******************************************")
        logging.info("*********  FITTING 1 COMPONENT  **********")
        logging.info("******************************************")
        run_dir = rdir + '{}/'.format(ncomps)
        mkpath(run_dir)

        # Reuse a previous run's final outputs if they can be loaded;
        # otherwise (IOError) run the fit from scratch.
        try:
            new_groups = dt.loadGroups(run_dir + 'final/final_groups.npy')
            new_meds = np.load(run_dir + 'final/final_med_errs.npy')
            new_z = np.load(run_dir + 'final/final_membership.npy')
            logging.info("Loaded from previous run")
        except IOError:
            # The fit takes the raw background density (bg_dens) ...
            new_groups, new_meds, new_z =\
                em.fit_many_comps(star_pars, ncomps, rdir=run_dir, pool=pool,
                                  bg_dens=BG_DENS,
                                  )
            new_groups = np.array(new_groups)

        # ... while scoring uses the per-star log background overlaps
        # (bg_ln_ols): overall log-likelihood, then the same with the prior
        # included (inc_posterior=True).
        new_lnlike = em.get_overall_lnlikelihood(star_pars, new_groups,
                                                 bg_ln_ols=bg_ln_ols)
        new_lnpost = em.get_overall_lnlikelihood(star_pars, new_groups,
                                                 bg_ln_ols=bg_ln_ols,
                                                 inc_posterior=True)
Beispiel #9
0
logging.info("{} tolerance".format(C_TOL))
logging.info("In the directory: {}".format(rdir))

## Destination: (inspired by LCC)
# Desired current-day XYZUVW mean of the synthetic association.
mean_now_lsr = np.array([50., -100., 25., 1.1, -7.76, 2.25])

# Calculate appropriate starting point
# Trace the current-day mean back by `age` to get the origin's mean.
# NOTE(review): age is not defined in this visible fragment.
mean_then = torb.trace_cartesian_orbit(mean_now_lsr, -age)
# gather inputs
# group_pars = np.hstack((mean_then, dX, dV, age, nstars))

# Setting up perfect current xyzuvw values
try:
    xyzuvw_now_perf = np.load(rdir+xyzuvw_perf_file)
    # origin = np.load(group_savefile)
    origin = dt.loadGroups(rdir+group_savefile)[0]
    logging.basicConfig(
        level=logging.INFO, filemode='a',
        filename='my_investigator_demo.log',
    )
    logging.info("appending to previous attempt")
except IOError:
    logging.basicConfig(
        level=logging.INFO, filemode='w',
        filename='my_investigator_demo.log',
    )
    logging.info("Beginning fresh run:")
    logging.info("Input arguments: {}".format(sys.argv[1:]))
    logging.info("\n"
                 "\tage:     {}\n"
                 "\tprecs:   {}".format(
Beispiel #10
0
# Measurement error setting (logged elsewhere as the "error fraction").
ERROR = 1.0
# One row of 10 synthesis parameters per synthetic group.
# NOTE(review): the last column (35, 25) looks like a star count; the exact
# meaning of the other columns is not established in this fragment --
# confirm against the synthesiser's parameter convention.
group_pars = np.array([
    [ 26.11, 38.20, 23.59,  0.34, -3.87, -0.36,  2.23, -0.05,  14.32, 35],
    [ 35.25, 52.58, 24.00, -2.80, -2.85,  0.25,  3.07, -0.27, 16.72, 25],
])
ngroups = group_pars.shape[0]

# Get background stars

# get typical background density

# nbgstars = 10
try:
    all_xyzuvw_now_perf = np.load(xyzuvw_perf_file)
    origins = dt.loadGroups(groups_savefile)
    star_pars = dt.loadXYZUVW(xyzuvw_conv_savefile)
    logging.info("Loaded synth data from previous run")
except IOError:
    all_xyzuvw_init = np.zeros((0,6))
    all_xyzuvw_now_perf = np.zeros((0,6))
    origins = []
    for i in range(ngroups):
        logging.info(" generating from group {}".format(i))
        # MANUALLY SEPARATE CURRENT DAY DISTROS IN DIMENSION X
        # mean_now_w_offset = mean_now.copy()
        # mean_now_w_offset[0] += i * 50
    
        # mean_then = torb.traceOrbitXYZUVW(mean_now_w_offset, -extra_pars[i,-2],
        #                                   single_age=True)
        xyzuvw_init, origin = syn.synthesiseXYZUVW(group_pars[i], form='sphere',
                # if dim1 == 0 and dim2 == 1 and debugging_circles:
                #     bpmg_range[1] = temp_range

    # To ensure consistency, we now plot the BANYAN bpmg stars only,
    # and use the ragnes from previous plot
    # 2019-07-13 [TC]: not sure why this is here...
    if False:
        fit_name = 'banyan_bpmg'
        rdir = '../../results/em_fit/beta_Pictoris/'

        memb_file = rdir + 'final_membership.npy'
        groups_file = rdir + 'final_best_groups.npy'
        star_pars_file = '../../data/beta_Pictoris_with_gaia_small_xyzuvw.fits'

        z = np.load(memb_file)
        groups = dt.loadGroups(groups_file)
        star_pars = dt.loadDictFromTable(star_pars_file, 'beta Pictoris')
        nstars = len(star_pars['xyzuvw'])

        # First do all, then just do possible membs of BPMG
        for dim1, dim2 in DEFAULT_DIMS:  #[(0,1), (0, 3), (1, 4), (2,5)]: #, (2, 5)]:  # , 'yv', 'zw']:
            # if dim1 == 0 and dim2 == 1 and debugging_circles:
            #     temp_range = bpmg_range[1]
            #     bpmg_range[1] = [-120, 80]
            # import pdb; pdb.set_trace()
            fp.plotPane(
                dim1,
                dim2,
                groups=groups,
                star_pars=star_pars,
                group_now=True,
Beispiel #12
0
import sys

import chronostar.likelihood

sys.path.insert(0, '..')

import chronostar.retired2.datatool as dt
import chronostar.retired2.converter as cv
import chronostar.coordinate as cc
import chronostar.expectmax as em

filename = '../data/2M0249-05.fits'
star_tab = Table.read(filename)
# First component from the final groups of a previous beta Pictoris fit.
final_groups = '../results/em_fit/beta_Pictoris_wgs_inv2_5B_res/' \
               'final_groups.npy'
beta_fit = dt.loadGroups(final_groups)[0]
gaia_xyzuvw_file = '../data/gaia_dr2_mean_xyzuvw.npy'

# If the table has no radial velocity, substitute a fixed value.
# NOTE(review): `if` on a table column is only unambiguous for a single-row
# table; presumably true for this file -- confirm.
if np.isnan(star_tab['radial_velocity']):
    print("Its nan")
    # from Shkolnik 2017
    # star_tab['radial_velocity'] = 14.4
    # star_tab['radial_velocity_error'] = 0.4

    # NOTE(review): source of the 16.44 +/- 1 value is not stated here.
    star_tab['radial_velocity'] = 16.44
    star_tab['radial_velocity_error'] = 1.

# Sanity check: should no longer print after the substitution above.
if np.isnan(star_tab['radial_velocity']):
    print("Its nan")

# extend proper motion uncertainty
Beispiel #13
0
np.save(bg_savefile, BG_DENS)

logging.info("  with error fraction {}".format(ERROR))
logging.info("  and background density {}".format(BG_DENS))
# Set a current-day location around which synth stars will end up
mean_now = np.array([50., -100., -0., -10., -20., -5.])

logging.info("Mean (now):\n{}".format(mean_now))
logging.info("Extra pars:\n{}".format(extra_pars))
logging.info("Offsets:\n{}".format(offsets))

# Refuse to regenerate existing synthetic data. If all three files load,
# raise UserWarning, which is not an IOError and so aborts the script.
# If any file is missing (IOError), generate fresh data instead.
try:
    #all_xyzuvw_now_perf = np.load(xyzuvw_perf_file)
    np.load(xyzuvw_perf_file)
    #origins = dt.loadGroups(groups_savefile)
    dt.loadGroups(groups_savefile)
    #star_pars = dt.loadXYZUVW(xyzuvw_conv_savefile)
    dt.loadXYZUVW(xyzuvw_conv_savefile)
    logging.info("Synth data exists! .....")
    print("Synth data exists")
    raise UserWarning
except IOError:
    # Empty (0, 6) accumulators for initial and perfect current-day XYZUVW.
    all_xyzuvw_init = np.zeros((0, 6))
    all_xyzuvw_now_perf = np.zeros((0, 6))
    origins = []
    for i in range(ngroups):
        logging.info(" generating from group {}".format(i))
        # MANUALLY SEPARATE CURRENT DAY DISTROS IN DIMENSION X
        # (now done via the per-group `offsets` rather than the
        # commented-out fixed X shift)
        mean_now_w_offset = mean_now.copy()
        # mean_now_w_offset[0] += i * 50
        mean_now_w_offset += offsets[i]
Beispiel #14
0
    Calculates the density of a pdf characterised by a histogram at point x
    """
    # Check if handling 1D histogram
    if len(bin_heights.shape) == 1:
        raise UserWarning
    dims = len(bin_heights.shape)
    bin_widths = [bins[1] - bins[0] for bins in bin_edges]
    bin_area = np.prod(bin_widths)

    x_ix = tuple([np.digitize(x[dim], bin_edges[dim]) - 1
                  for dim in range(dims)])
    return bin_heights[x_ix] / bin_area


rdir = 'example_comps/'
# Components and memberships from iteration 0 and iteration 15 (per the
# directory names).
init_groups = dt.loadGroups(rdir+'iter00/best_groups.npy')
init_z = np.load(rdir+'iter00/membership.npy')
later_groups = dt.loadGroups(rdir + 'iter15/best_groups.npy')
later_z = np.load(rdir+'iter15/membership.npy')

# Plot gf.lnlognormal over [2, 100] for a few sigma values.
xs = np.linspace(2, 100,10000)
plt.clf()
for sig in [.5, .75, 1]:
    plt.plot(xs, gf.lnlognormal(xs, sig=sig), label='{}'.format(sig))
plt.legend(loc='best')
plt.savefig("lnlognormal.pdf")

print("Init alphas")
# NOTE(review): loop body looks truncated -- sig, igroup and lgroup are
# unused in the visible lines.
for i, (igroup, lgroup) in enumerate(zip(init_groups, later_groups)):
    sig=1.
    print('Group {}'.format(i))
Beispiel #15
0
    plt_file = rdir + 'multi_plot_{}_{}_{}_{}_{}_{}.pdf'.format(*scenario)
    print("Checking {}".format(plt_file))
    if not os.path.isfile(plt_file):
        print("Plotting {}".format(plt_file))
        try:
            star_pars_file = rdir + 'xyzuvw_now.fits'
            chain_file = rdir + 'final_chain.npy'
            origins_file = rdir + 'origins.npy'
            lnprob_file = rdir + 'final_lnprob.npy'

            # One 9-parameter sample per row.
            chain = np.load(chain_file).reshape(-1, 9)
            lnprob = np.load(lnprob_file)
            # Highest-probability sample. argmax must be taken over the
            # loaded lnprob values, not the filename string (which always
            # yields index 0). The flat index matches chain.reshape(-1, 9)
            # provided lnprob has the chain's leading shape
            # (NOTE(review): presumably (nwalkers, nsteps); confirm).
            best_pars = chain[np.argmax(lnprob)]
            best_fit = chronostar.component.Component(best_pars, internal=True)
            origins = dt.loadGroups(origins_file)

            star_pars = dt.loadXYZUVW(star_pars_file)

            fp.plotMultiPane(
                ['xy', 'xz', 'uv', 'xu', 'yv', 'zw'],
                star_pars,
                [best_fit],
                origins=origins,
                save_file=rdir +
                'multi_plot_{}_{}_{}_{}_{}_{}.pdf'.format(*scenario),
                title='{}Myr, {}pc, {}km/s, {} stars, {}, {}'.format(
                    *scenario),
            )
            print("done")
        except:
Beispiel #16
0
def plotPane(dim1=0,
             dim2=1,
             ax=None,
             groups=(),
             star_pars=None,
             origin_star_pars=None,
             star_orbits=False,
             origins=None,
             group_then=False,
             group_now=False,
             group_orbit=False,
             annotate=False,
             membership=None,
             true_memb=None,
             savefile='',
             with_bg=False,
             markers=None,
             group_bg=False,
             marker_labels=None,
             color_labels=None,
             marker_style=None,
             marker_legend=None,
             color_legend=None,
             star_pars_label=None,
             origin_star_pars_label=None,
             range_1=None,
             range_2=None,
             isotropic=False,
             ordering=None,
             no_bg_covs=False):
    """
    Plots a single pane capturing kinematic info in any desired 2D plane

    Uses the module-level constants COLORS, HATCHES, MARKERS, MARK_SIZE,
    BG_MARK_SIZE, PT_ALPHA, COV_ALPHA and FONTSIZE to keep a consistent
    colour/marker scheme, so it can be used to plot the different panes of
    one whole figure.

    Parameters
    ----------
    dim1: x-axis, can either be an integer 0-5 (inclusive; numpy integers
          are accepted) or a letter from 'xyzuvw' (either case)
    dim2: y-axis, same conditions as dim1
    ax:   the axes object on which to plot (defaults to pyplot's current axes)
    groups: a list of (or just one) component objects, or a filename to load
            them from, corresponding to the fit of the origin(s). Each must
            provide get_covmatrix() and get_mean(); get_age() and .age are
            also used for group_now / group_orbit.
    star_pars:  dict object (or filename) with key 'xyzuvw' ([nstars,6]
                array of current star means) and optionally 'xyzuvw_cov'
                ([nstars,6,6] array of current star covariance matrices).
                If 'xyzuvw_cov' is missing, a list of None is inserted into
                the caller's dict.
    origin_star_pars: dict with key 'xyzuvw', plotted as small squares in
                the same colours as the stars of star_pars
    star_orbits: (bool) plot the calculated stellar traceback orbits of
                        central estimate of measurements (every third star)
    origins: iterable of true origins, drawn in grey
    group_then: (bool) plot the group's origin
    group_now:  (bool) plot the group's current day distribution
    group_orbit: (bool) plot the trajectory of the group's mean
    annotate: (bool) add text describing the figure's contents
    membership: [nstars, ncomps] array (or .npy filename) of fitted
                memberships; stars are coloured by most likely component
    true_memb: [nstars, ncomps] array of true memberships, used to choose
               star markers (only applied when membership is also given)
    savefile: if non-empty, clear the current figure first and save to this
              path at the end
    with_bg: (bool) treat the last column in Z as members of background, and
            color accordingly. Requires membership.
    markers: per-star markers (default '.')
    group_bg: (bool) background is fitted by the final component
    marker_labels, color_labels, marker_legend, color_legend: optional
               legend entries, drawn via off-screen dummy points
    marker_style: markers used for marker_labels (default MARKERS)
    star_pars_label, origin_star_pars_label: labels collected for a legend
               (currently not drawn)
    range_1, range_2: optional axis limits for dim1 / dim2; range_1 is
               ignored when isotropic is set
    isotropic: (bool) equal aspect ratio, expanding the narrower span
    ordering: index order into marker_style for marker_labels
    no_bg_covs: (bool) ignore covariance matrices of stars fitted to background

    Returns
    -------
    (xlim, ylim): the axes limits after plotting
    """
    labels = 'XYZUVW'
    units = 3 * ['pc'] + 3 * ['km/s']

    if savefile:
        plt.clf()

    # Tidying up inputs
    if ax is None:
        ax = plt.gca()
    # Only letters need translating; any integer (incl. numpy ints) is an
    # index already.
    if isinstance(dim1, str):
        dim1 = labels.index(dim1.upper())
    if isinstance(dim2, str):
        dim2 = labels.index(dim2.upper())
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(membership, str):
        membership = np.load(membership)
    if isinstance(groups, str):
        groups = dt.loadGroups(groups)
    if marker_style is None:
        marker_style = MARKERS[:]

    # Legend handles/labels are collected but not currently drawn.
    legend_pts = []
    legend_labels = []

    # ensure groups is iterable (a single component object has no len())
    try:
        len(groups)
    except TypeError:
        groups = [groups]
    ngroups = len(groups)
    if ordering is None:
        ordering = range(len(marker_style))

    # plot stellar data (positions with errors and optionally traceback
    # orbits back to some ill-defined age)
    if star_pars:
        nstars = star_pars['xyzuvw'].shape[0]

        # apply default color and markers, to be overwritten if needed
        pt_colors = np.array(nstars * [COLORS[0]])
        if markers is None:
            markers = np.array(nstars * ['.'])

        # Incorporate fitted membership into colors of the pts
        if membership is not None:
            best_mship = np.argmax(membership[:, :ngroups + with_bg], axis=1)
            pt_colors = np.array(COLORS[:ngroups] +
                                 with_bg * ['xkcd:grey'])[best_mship]
            # Incorporate "True" membership into pt markers
            if true_memb is not None:
                markers = np.array(MARKERS)[np.argmax(true_memb, axis=1)]
                if with_bg:
                    true_bg_mask = np.where(true_memb[:, -1] == 1.)
                    markers[true_bg_mask] = '.'
        all_mark_size = np.array(nstars * [MARK_SIZE])

        # Stars whose most likely component is the background; group_bg
        # handles the case where background is fitted by the final component
        if with_bg:
            bg_mask = np.argmax(membership, axis=1) == ngroups - group_bg
            all_mark_size[bg_mask] = BG_MARK_SIZE

        mns = star_pars['xyzuvw']
        try:
            # Kept as a list so individual entries can be replaced by None.
            # Assigning None into a float array stores NaNs instead, so the
            # `is not None` check below would never skip those ellipses.
            covs = list(np.copy(star_pars['xyzuvw_cov']))
        except KeyError:
            covs = len(mns) * [None]
            star_pars['xyzuvw_cov'] = covs
        else:
            # replace background cov matrices with None so as to avoid plotting
            if with_bg and no_bg_covs:
                print("Discarding background cov-mats")
                for bg_ix in np.flatnonzero(bg_mask):
                    covs[bg_ix] = None

        # plot traceback orbits for as long as oldest group (if known),
        # else 30 Myr. Same for every star, so computed once.
        if star_orbits:
            try:
                tb_limit = max([g.age for g in groups])
            except (AttributeError, TypeError, ValueError):
                tb_limit = 30

        star_pt = None
        for st_count, (star_mn, star_cov, marker, pt_color, m_size) in \
                enumerate(zip(mns, covs, markers, pt_colors, all_mark_size)):
            star_pt = ax.scatter(
                star_mn[dim1],
                star_mn[dim2],
                s=m_size,
                color=pt_color,
                marker=marker,
                alpha=PT_ALPHA,
                linewidth=0.0,
            )
            # plot uncertainties
            if star_cov is not None:
                plotCovEllipse(
                    star_cov[np.ix_([dim1, dim2], [dim1, dim2])],
                    star_mn[np.ix_([dim1, dim2])],
                    ax=ax,
                    alpha=COV_ALPHA,
                    linewidth='0.1',
                    color=pt_color,
                )
            # traceback orbits only for every third star
            if star_orbits and st_count % 3 == 0:
                plotOrbit(star_mn,
                          dim1,
                          dim2,
                          ax,
                          end_age=-tb_limit,
                          color='xkcd:grey')
        if star_pars_label and star_pt is not None:
            legend_pts.append(star_pt)
            legend_labels.append(star_pars_label)

        if origin_star_pars is not None:
            origin_pt = None
            # markers is zipped only to keep the same pairing/length as the
            # stars above; origin stars are always drawn as squares.
            for star_mn, _, pt_color, m_size in zip(
                    origin_star_pars['xyzuvw'], markers, pt_colors,
                    all_mark_size):
                origin_pt = ax.scatter(
                    star_mn[dim1],
                    star_mn[dim2],
                    s=0.5 * m_size,
                    color=pt_color,
                    marker='s',
                    alpha=PT_ALPHA,
                    linewidth=0.0,
                )
            if origin_star_pars_label and origin_pt is not None:
                legend_pts.append(origin_pt)
                legend_labels.append(origin_star_pars_label)

    # plot info for each group (fitted, or true synthetic origin)
    for i, group in enumerate(groups):
        cov_then = group.get_covmatrix()
        mean_then = group.get_mean()
        # plot group initial distribution
        if group_then:
            ax.plot(mean_then[dim1],
                    mean_then[dim2],
                    marker='+',
                    alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_then[np.ix_([dim1, dim2])],
                           with_line=True,
                           ax=ax,
                           alpha=0.3,
                           ls='--',
                           color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_0, \mathbf{\Sigma}_0$',
                            (mean_then[dim1], mean_then[dim2]),
                            color=COLORS[i])

        # plot group current day distribution (should match well with stars)
        mean_now = None
        if group_now:
            mean_now = torb.trace_cartesian_orbit(mean_then,
                                                  group.get_age(),
                                                  single_age=True)
            cov_now = tf.transform_covmatrix(cov_then,
                                             torb.trace_cartesian_orbit,
                                             mean_then,
                                             args=[group.get_age()])
            ax.plot(mean_now[dim1],
                    mean_now[dim2],
                    marker='+',
                    alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(
                cov_now[np.ix_([dim1, dim2], [dim1, dim2])],
                mean_now[np.ix_([dim1, dim2])],
                ax=ax,
                alpha=0.4,
                ls='-.',
                ec=COLORS[i],
                fill=False,
                hatch=HATCHES[i],
                color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_c, \mathbf{\Sigma}_c$',
                            (mean_now[dim1], mean_now[dim2]),
                            color=COLORS[i])

        # plot orbit of mean of group, starting from its current-day mean
        # (traced here when group_now did not already compute it)
        if group_orbit:
            if mean_now is None:
                mean_now = torb.trace_cartesian_orbit(mean_then,
                                                      group.get_age(),
                                                      single_age=True)
            plotOrbit(mean_now,
                      dim1,
                      dim2,
                      ax,
                      -group.age,
                      group_ix=i,
                      with_arrow=True,
                      annotate=annotate)

    # `is not None` (rather than truthiness) so numpy arrays are accepted
    if origins is not None:
        for origin in origins:
            cov_then = origin.generateSphericalCovMatrix()
            mean_then = origin.mean
            # plot origin initial distribution
            ax.plot(mean_then[dim1],
                    mean_then[dim2],
                    marker='+',
                    color='xkcd:grey')
            plotCovEllipse(cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_then[np.ix_([dim1, dim2])],
                           with_line=True,
                           ax=ax,
                           alpha=0.1,
                           ls='--',
                           color='xkcd:grey')

    ax.set_xlabel("{} [{}]".format(labels[dim1], units[dim1]))
    ax.set_ylabel("{} [{}]".format(labels[dim2], units[dim2]))

    # update fontsize
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(FONTSIZE)
    if ax.get_legend() is not None:
        for item in ax.get_legend().get_texts():
            item.set_fontsize(FONTSIZE)

    if range_2:
        ax.set_ylim(range_2)

    if isotropic:
        print("Setting isotropic for dims {} and {}".format(dim1, dim2))
        # Apply to the target axes, which need not be pyplot's current axes.
        ax.set_aspect('equal', adjustable='datalim')

        # manually calculate what the new xaxis must be...
        figW, figH = ax.get_figure().get_size_inches()
        xmid = (ax.get_xlim()[1] + ax.get_xlim()[0]) * 0.5
        yspan = ax.get_ylim()[1] - ax.get_ylim()[0]
        xspan = figW * yspan / figH

        # check if this increases span
        if xspan > ax.get_xlim()[1] - ax.get_xlim()[0]:
            ax.set_xlim(xmid - 0.5 * xspan, xmid + 0.5 * xspan)
        # if not, need to increase yspan
        else:
            ymid = (ax.get_ylim()[1] + ax.get_ylim()[0]) * 0.5
            xspan = ax.get_xlim()[1] - ax.get_xlim()[0]
            yspan = figH * xspan / figW
            ax.set_ylim(ymid - 0.5 * yspan, ymid + 0.5 * yspan)
    elif range_1:
        ax.set_xlim(range_1)

    # Legend entries are drawn as off-screen dummy artists; limits are
    # restored afterwards so they don't affect the view.
    if color_labels is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for i, color_label in enumerate(color_labels):
            ax.plot(1e10, 1e10, color=COLORS[i], label=color_label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if marker_labels is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for i, marker_label in enumerate(marker_labels):
            ax.scatter(
                1e10,
                1e10,
                c='black',
                marker=np.array(marker_style)[ordering][i],
                label=marker_label)
        if with_bg:
            ax.scatter(1e10,
                       1e10,
                       c='xkcd:grey',
                       marker='.',
                       label='Background')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if color_legend is not None and marker_legend is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for label in color_legend.keys():
            ax.scatter(1e10,
                       1e10,
                       c=color_legend[label],
                       marker=marker_legend[label],
                       label=label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if savefile:
        plt.savefig(savefile)

    return ax.get_xlim(), ax.get_ylim()
                           right=True)
        print("... saving")
        plt.savefig(plot_name)

if PLOT_BPMG_REAL:
    for iteration in ['5B']: #, '6C']:
        star_pars_file = '../../data/beta_Pictoris_with_gaia_small_xyzuvw.fits'
        star_pars = dt.loadXYZUVW(star_pars_file)
        fit_name = 'bpmg_and_nearby'
        rdir = '../../results/em_fit/beta_Pictoris_wgs_inv2_{}_res/'.format(iteration)

        memb_file = rdir + 'final_membership.npy'
        groups_file = rdir + 'final_groups.npy'

        z = np.load(memb_file)
        groups = dt.loadGroups(groups_file)

        # Assign markers based on BANYAN membership
        gt_sp = dt.loadDictFromTable('../../data/banyan_with_gaia_near_bpmg_xyzuvw.fits')
        banyan_membership = len(star_pars['xyzuvw']) * ['N/A']
        for i in range(len(star_pars['xyzuvw'])):
            master_table_ix = np.where(gt_sp['table']['source_id']==star_pars['gaia_ids'][i])
            banyan_membership[i] = gt_sp['table']['Moving group'][master_table_ix[0][0]]

        # assign markers based on present moving groups, keep track of
        # assoc -> marker relationship incase a legend is called for
        banyan_membership=np.array(banyan_membership)
        banyan_markers = np.array(len(banyan_membership) * ['.'])

        banyan_memb_set = set(banyan_membership)
        for bassoc in set(gt_sp['table']['Moving group']):
np.save(bg_savefile, BG_DENS)

logging.info("  with error fraction {}".format(ERROR))
logging.info("  and background density {}".format(BG_DENS))
# Set a current-day location around which synth stars will end up
mean_now = np.array([50., -100., -0., -10., -20., -5.])

logging.info("Mean (now):\n{}".format(mean_now))
logging.info("Extra pars:\n{}".format(extra_pars))
logging.info("Offsets:\n{}".format(offsets))

# Refuse to regenerate existing synthetic data. If all three files load,
# raise UserWarning, which is not an IOError and so aborts the script.
# If any file is missing (IOError), generate fresh data instead.
try:
    #all_xyzuvw_now_perf = np.load(xyzuvw_perf_file)
    np.load(xyzuvw_perf_file)
    #origins = dt.loadGroups(groups_savefile)
    dt.loadGroups(groups_savefile)
    #star_pars = dt.loadXYZUVW(xyzuvw_conv_savefile)
    dt.loadXYZUVW(xyzuvw_conv_savefile)
    logging.info("Synth data exists! .....")
    print("Synth data exists")
    raise UserWarning
except IOError:
    # Empty (0, 6) accumulators for initial and perfect current-day XYZUVW.
    all_xyzuvw_init = np.zeros((0,6))
    all_xyzuvw_now_perf = np.zeros((0,6))
    origins = []
    for i in range(ngroups):
        logging.info(" generating from group {}".format(i))
        # MANUALLY SEPARATE CURRENT DAY DISTROS IN DIMENSION X
        # (now done via the per-group `offsets` rather than the
        # commented-out fixed X shift)
        mean_now_w_offset = mean_now.copy()
        # mean_now_w_offset[0] += i * 50
        mean_now_w_offset += offsets[i]
Beispiel #19
0
        final_med_errs.append(fmed_errs)
    else:
        comp_fgroups = []
        comp_fmembs = []
        comp_fmed_errs = []
        for i in range(ncomps-1):
            subrdir = rdir + '{}/{}/final/'.format(ncomps, chr(ord('A') + i))
            if os.path.isdir(subrdir):
                fgroups, fmembs, fmed_errs = loadFinalResults(subrdir)
                comp_fgroups.append(fgroups)
                comp_fmembs.append(fmembs)
                comp_fmed_errs.append(fmed_errs)
        final_fits.append(comp_fgroups)
        final_membs.append(comp_fmembs)
        final_med_errs.append(comp_fmed_errs)
    ncomps += 1

# load in the synthetic data (only present for synthetic runs)
sdir = rdir + 'synth_data/'
if os.path.isdir(sdir):
    origins = dt.loadGroups(sdir + 'origins.npy')
    perf_xyzuvw = np.load(sdir + 'perf_xyzuvw.npy')
    # True memberships of the synthetic stars, and how many belong to an
    # association (the rest are background).
    true_z = dt.getZfromOrigins(origins, star_pars)
    nassoc_stars = np.sum([o.nstars for o in origins])

# Precomputed per-star log background overlaps. If unavailable, fall back to
# None (the convention for "no background term" when passing bg_ln_ols) so
# later references don't raise NameError.
try:
    bg_ln_ols = np.load(rdir + 'bg_ln_ols.npy')
except IOError:
    print("couldn't find background ln overlaps...")
    bg_ln_ols = None

import sys

# The parent directory must be on the module search path BEFORE any
# chronostar import. Otherwise the first import binds `chronostar` to some
# other installation (or fails), and every later chronostar.* import
# resolves through that package instead of the local checkout.
sys.path.insert(0, '..')

import chronostar.likelihood
import chronostar.retired2.datatool as dt
import chronostar.retired2.converter as cv
import chronostar.coordinate as cc
import chronostar.expectmax as em

filename = '../data/2M0249-05.fits'
star_tab = Table.read(filename)
final_groups = '../results/em_fit/beta_Pictoris_wgs_inv2_5B_res/' \
               'final_groups.npy'
# first fitted group of the saved final fit
beta_fit = dt.loadGroups(final_groups)[0]
gaia_xyzuvw_file = '../data/gaia_dr2_mean_xyzuvw.npy'

# Fill in a missing radial velocity with an assumed value and uncertainty.
if np.isnan(star_tab['radial_velocity']):
    print("Its nan")
    # from Shkolnik 2017
    # star_tab['radial_velocity'] = 14.4
    # star_tab['radial_velocity_error'] = 0.4

    # NOTE(review): the active value differs from the Shkolnik 2017 value
    # above; its source is not recorded here.
    star_tab['radial_velocity'] = 16.44
    star_tab['radial_velocity_error'] = 1.

# Sanity check: the radial velocity should no longer be NaN.
if np.isnan(star_tab['radial_velocity']):
    print("Its nan")

# extend proper motion uncertainty
Beispiel #21
0
def plotPane(dim1=0, dim2=1, ax=None, groups=(), star_pars=None,
             origin_star_pars=None,
             star_orbits=False, origins=None,
             group_then=False, group_now=False, group_orbit=False,
             annotate=False, membership=None, true_memb=None,
             savefile='', with_bg=False, markers=None, group_bg=False,
             marker_labels=None, color_labels=None,
             marker_style=None,
             marker_legend=None, color_legend=None,
             star_pars_label=None, origin_star_pars_label=None,
             range_1=None, range_2=None, isotropic=False,
             ordering=None, no_bg_covs=False):
    """
    Plots a single pane capturing kinematic info in any desired 2D plane

    Uses global constants COLORS, MARKERS and HATCHES to inform a consistent
    colour scheme. Can be used to plot the different panes of one whole
    figure.

    Parameters
    ----------
    dim1: x-axis, either an integer index 0-5 (inclusive) or a letter from
          'xyzuvw' (either case)
    dim2: y-axis, same conditions as dim1
    ax:   the axes object on which to plot (defaults to pyplot's current axes)
    groups: a list of (or just one) group objects corresponding to the fit
            of the origin(s), or a filename loaded with dt.loadGroups. Each
            group must provide `mean`, `age` and
            `generateSphericalCovMatrix()`.
    star_pars:  dict with keys 'xyzuvw' ([nstars,6] array of current star
                means) and optionally 'xyzuvw_cov' ([nstars,6,6] array of
                current star covariance matrices), or a filename loaded with
                dt.loadXYZUVW. If 'xyzuvw_cov' is missing, a list of None is
                stored under that key (the dict is modified in place).
    origin_star_pars: dict with key 'xyzuvw' holding star positions at
                their origin; drawn as half-size squares coloured like the
                current-day stars. Only used when star_pars is given.
    star_orbits: (bool) plot the traceback orbit of every third star's
                 central estimate, back to the oldest group age (30 if no
                 group ages are available)
    origins: list of true origin group objects (or a filename loaded with
             dt.loadGroups), drawn in grey
    group_then: (bool) plot the group's origin
    group_now:  (bool) plot the group's current day distribution
    group_orbit: (bool) plot the trajectory of the group's mean
    annotate: (bool) add text describing the figure's contents
    membership: 2D [nstars, ncomps] array (or a filename loaded with
                np.load); each star is coloured by its most probable
                component among the first ngroups (+1 if with_bg) columns
    true_memb: 2D [nstars, ncomps] array of "true" memberships; when given
               together with `membership`, each star's marker is taken from
               MARKERS by its true component
    savefile: if non-empty, the current figure is cleared before plotting and
              saved to this path at the end
    with_bg: (bool) treat the membership column after the fitted groups as
             the background, colour those stars grey and shrink their markers
    markers: per-star array of marker styles (defaults to '.')
    group_bg: (bool) the background is fitted by the final component in
              `groups`, so background stars are those whose most probable
              column is ngroups-1 rather than ngroups
    marker_labels: legend labels, one per marker taken from
                   marker_style[ordering]
    color_labels: legend labels, one per colour taken from COLORS
    marker_style: markers used for the marker_labels legend (defaults to a
                  copy of MARKERS)
    marker_legend, color_legend: dicts mapping a legend label to a marker and
                  to a colour respectively; drawn only if both are given
    star_pars_label, origin_star_pars_label: accepted for backwards
                  compatibility; currently not drawn
    range_1, range_2: x- and y-axis limits; range_1 is ignored when
                      `isotropic` is set
    isotropic: (bool) give both axes an equal data aspect ratio
    ordering: indices into marker_style giving the markers used for
              marker_labels (defaults to the natural order)
    no_bg_covs: (bool) skip the covariance ellipses of stars fitted to the
                background (needs with_bg and membership)

    Returns
    -------
    (xlim, ylim): the final x- and y-axis limits of `ax`
    """
    labels = 'XYZUVW'
    units = 3 * ['pc'] + 3 * ['km/s']

    if savefile:
        plt.clf()

    # Tidying up inputs (filenames are accepted in place of loaded data)
    if ax is None:
        ax = plt.gca()
    if isinstance(dim1, str):
        dim1 = labels.index(dim1.upper())
    if isinstance(dim2, str):
        dim2 = labels.index(dim2.upper())
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(membership, str):
        membership = np.load(membership)
    if isinstance(groups, str):
        groups = dt.loadGroups(groups)
    if isinstance(origins, str):
        origins = dt.loadGroups(origins)
    if marker_style is None:
        marker_style = MARKERS[:]

    # ensure groups is iterable (a single group object may be passed)
    try:
        len(groups)
    except TypeError:
        groups = [groups]
    ngroups = len(groups)
    if ordering is None:
        ordering = range(len(marker_style))

    # plot stellar data (positions with errors and optionally traceback
    # orbits)
    if star_pars:
        mns = star_pars['xyzuvw']
        nstars = mns.shape[0]

        # apply default colour and markers, to be overwritten if needed
        pt_colors = np.array(nstars * [COLORS[0]])
        if markers is None:
            markers = np.array(nstars * ['.'])

        # Incorporate fitted membership into colours of the pts
        if membership is not None:
            best_mship = np.argmax(membership[:, :ngroups + with_bg], axis=1)
            pt_colors = np.array(list(COLORS[:ngroups])
                                 + with_bg * ['xkcd:grey'])[best_mship]
            # Incorporate "true" membership into pt markers
            if true_memb is not None:
                markers = np.array(MARKERS)[np.argmax(true_memb, axis=1)]
                if with_bg:
                    true_bg_mask = np.where(true_memb[:, -1] == 1.)
                    markers[true_bg_mask] = '.'
        all_mark_size = np.array(nstars * [MARK_SIZE])

        # Stars whose most probable component is the background. group_bg
        # handles the case where the background is fitted by the final
        # component. Without memberships no star can be flagged as background.
        bg_mask = None
        if with_bg and membership is not None:
            bg_mask = np.argmax(membership, axis=1) == ngroups - group_bg
            all_mark_size[bg_mask] = BG_MARK_SIZE

        try:
            # A list (not a float array) so entries can be replaced by None
            covs = list(np.copy(star_pars['xyzuvw_cov']))
        except KeyError:
            covs = len(mns) * [None]
            star_pars['xyzuvw_cov'] = covs
        else:
            if no_bg_covs and bg_mask is not None:
                # Assigning None into a float ndarray would store NaN, which
                # still passes the `is not None` check below; blank the list
                # entries instead so these ellipses are really skipped.
                print("Discarding background cov-mats")
                covs = [None if is_bg else cov
                        for cov, is_bg in zip(covs, bg_mask)]

        if star_orbits:
            # trace back for as long as the oldest group (if known), else 30
            try:
                tb_limit = max([g.age for g in groups])
            except (AttributeError, TypeError, ValueError):
                tb_limit = 30

        star_rows = zip(mns, covs, markers, pt_colors, all_mark_size)
        for st_ix, (star_mn, star_cov, marker, pt_color, m_size) in \
                enumerate(star_rows):
            ax.scatter(star_mn[dim1], star_mn[dim2], s=m_size,
                       color=pt_color, marker=marker, alpha=PT_ALPHA,
                       linewidth=0.0)
            # plot uncertainties
            if star_cov is not None:
                plotCovEllipse(star_cov[np.ix_([dim1, dim2], [dim1, dim2])],
                               star_mn[np.ix_([dim1, dim2])],
                               ax=ax, alpha=COV_ALPHA, linewidth='0.1',
                               color=pt_color)
            # only every third star's traceback orbit is drawn
            if star_orbits and st_ix % 3 == 0:
                plotOrbit(star_mn, dim1, dim2, ax, end_age=-tb_limit,
                          color='xkcd:grey')

        if origin_star_pars is not None:
            # `markers` is zipped only so that iteration stops at the same
            # length as the current-day stars; origins are always squares.
            for star_mn, _, pt_color, m_size in zip(
                    origin_star_pars['xyzuvw'], markers, pt_colors,
                    all_mark_size):
                ax.scatter(star_mn[dim1], star_mn[dim2], s=0.5 * m_size,
                           color=pt_color, marker='s', alpha=PT_ALPHA,
                           linewidth=0.0)

    # plot info for each group (fitted, or true synthetic origin)
    for i, group in enumerate(groups):
        cov_then = group.generateSphericalCovMatrix()
        mean_then = group.mean
        # plot group initial distribution
        if group_then:
            ax.plot(mean_then[dim1], mean_then[dim2], marker='+', alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_then[np.ix_([dim1, dim2])],
                           with_line=True,
                           ax=ax, alpha=0.3, ls='--',
                           color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_0, \mathbf{\Sigma}_0$',
                            (mean_then[dim1], mean_then[dim2]),
                            color=COLORS[i])

        # The current-day mean is needed both for the current distribution
        # and as the start of the group's orbit, so compute it whenever
        # either is requested (group_orbit alone used to hit an unbound name).
        if group_now or group_orbit:
            mean_now = torb.trace_cartesian_orbit(mean_then, group.age,
                                                  single_age=True)

        # plot group current day distribution (should match well with stars)
        if group_now:
            cov_now = tf.transform_covmatrix(cov_then,
                                             torb.trace_cartesian_orbit,
                                             mean_then, args=[group.age])
            ax.plot(mean_now[dim1], mean_now[dim2], marker='+', alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(cov_now[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_now[np.ix_([dim1, dim2])],
                           ax=ax, alpha=0.4, ls='-.',
                           ec=COLORS[i], fill=False, hatch=HATCHES[i],
                           color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_c, \mathbf{\Sigma}_c$',
                            (mean_now[dim1], mean_now[dim2]),
                            color=COLORS[i])

        # plot orbit of mean of group
        if group_orbit:
            plotOrbit(mean_now, dim1, dim2, ax, -group.age, group_ix=i,
                      with_arrow=True, annotate=annotate)

    # plot true origins in grey; `is not None` (not truthiness) so that a
    # numpy array of groups is accepted as well as a list
    if origins is not None:
        for origin in origins:
            cov_then = origin.generateSphericalCovMatrix()
            mean_then = origin.mean
            # plot origin initial distribution
            ax.plot(mean_then[dim1], mean_then[dim2], marker='+',
                    color='xkcd:grey')
            plotCovEllipse(
                cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                mean_then[np.ix_([dim1, dim2])],
                with_line=True,
                ax=ax, alpha=0.1, ls='--',
                color='xkcd:grey')

    ax.set_xlabel("{} [{}]".format(labels[dim1], units[dim1]))
    ax.set_ylabel("{} [{}]".format(labels[dim2], units[dim2]))

    # update fontsize
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(FONTSIZE)
    if ax.get_legend() is not None:
        for item in ax.get_legend().get_texts():
            item.set_fontsize(FONTSIZE)

    if range_2:
        ax.set_ylim(range_2)

    if isotropic:
        print("Setting isotropic for dims {} and {}".format(dim1, dim2))
        # apply to the axes being drawn on, not pyplot's current axes
        ax.set_aspect('equal', adjustable='datalim')

        # manually calculate what the new axis span must be so that the data
        # region matches the figure's width/height ratio
        fig_w, fig_h = ax.get_figure().get_size_inches()
        xmid = (ax.get_xlim()[1] + ax.get_xlim()[0]) * 0.5
        yspan = ax.get_ylim()[1] - ax.get_ylim()[0]
        xspan = fig_w * yspan / fig_h

        # check if this increases the x span
        if xspan > ax.get_xlim()[1] - ax.get_xlim()[0]:
            ax.set_xlim(xmid - 0.5 * xspan, xmid + 0.5 * xspan)
        # if not, need to increase the y span instead
        else:
            ymid = (ax.get_ylim()[1] + ax.get_ylim()[0]) * 0.5
            xspan = ax.get_xlim()[1] - ax.get_xlim()[0]
            yspan = fig_h * xspan / fig_w
            ax.set_ylim(ymid - 0.5 * yspan, ymid + 0.5 * yspan)
    elif range_1:
        ax.set_xlim(range_1)

    # The legend sections below draw dummy artists far off-screen (1e10)
    # purely to create legend entries, restoring the axis limits afterwards.
    if color_labels is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for i, color_label in enumerate(color_labels):
            ax.plot(1e10, 1e10, color=COLORS[i], label=color_label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if marker_labels is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for i, marker_label in enumerate(marker_labels):
            ax.scatter(1e10, 1e10, c='black',
                       marker=np.array(marker_style)[ordering][i],
                       label=marker_label)
        if with_bg:
            ax.scatter(1e10, 1e10, c='xkcd:grey',
                       marker='.', label='Background')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if color_legend is not None and marker_legend is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for label in color_legend:
            ax.scatter(1e10, 1e10, c=color_legend[label],
                       marker=marker_legend[label], label=label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if savefile:
        plt.savefig(savefile)

    return ax.get_xlim(), ax.get_ylim()