def getZfromOrigins(origins, star_pars):
    """
    Build a one-hot "true" membership matrix Z from synthetic origins.

    Stars are taken to be stored contiguously by group, in the same order
    as `origins`: the first origins[0].nstars rows belong to group 0, the
    next origins[1].nstars rows to group 1, and so on. Any stars left over
    once every origin is accounted for are assigned to an extra, final
    "background" column.

    Parameters
    ----------
    origins: iterable of objects with an integer `nstars` attribute, or a
        string filename loadable by dt.loadGroups
    star_pars: dict with key 'xyzuvw' (array whose first axis is the star
        index), or a string filename loadable by dt.loadXYZUVW

    Returns
    -------
    z: float array of shape [nstars, ngroups], or [nstars, ngroups + 1]
        when there are more stars than the origins account for (the final
        column is then the background). Every row contains a single 1.

    Raises
    ------
    ValueError: if the origins claim more stars than star_pars contains
    """
    if isinstance(origins, str):
        origins = dt.loadGroups(origins)
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)

    nstars = star_pars['xyzuvw'].shape[0]
    ngroups = len(origins)
    nassoc_stars = int(np.sum([o.nstars for o in origins]))
    if nassoc_stars > nstars:
        # Previously this silently truncated the last group(s)
        raise ValueError(
            "origins account for {} stars but star_pars only has {}".format(
                nassoc_stars, nstars))
    using_bg = nstars > nassoc_stars

    z = np.zeros((nstars, ngroups + int(using_bg)))

    # set association members' memberships to 1
    stars_so_far = 0
    for i, o in enumerate(origins):
        z[stars_so_far:stars_so_far + o.nstars, i] = 1.
        stars_so_far += o.nstars

    # set remaining stars as members of background
    if using_bg:
        z[stars_so_far:, -1] = 1.
    return z
print("Checking {}".format(plt_file))
if not os.path.isfile(plt_file):
    print("Plotting {}".format(plt_file))
    try:
        star_pars_file = rdir + 'xyzuvw_now.fits'
        chain_file = rdir + 'final_chain.npy'
        origins_file = rdir + 'origins.npy'
        lnprob_file = rdir + 'final_lnprob.npy'

        # Flatten the sampler chain to one row of 9 parameters per sample.
        # NOTE(review): assumes the stored chain's leading axes match the
        # shape of lnprob, so flat indices line up -- confirm.
        chain = np.load(chain_file).reshape(-1, 9)
        lnprob = np.load(lnprob_file)
        # argmax of the *loaded* lnprob array (np.argmax flattens by
        # default). Taking argmax of the filename string always gives 0,
        # which silently picked the first sample instead of the best one.
        best_pars = chain[np.argmax(lnprob)]
        best_fit = chronostar.component.Component(best_pars, internal=True)

        origins = dt.loadGroups(origins_file)
        star_pars = dt.loadXYZUVW(star_pars_file)

        fp.plotMultiPane(
            ['xy', 'xz', 'uv', 'xu', 'yv', 'zw'],
            star_pars,
            [best_fit],
            origins=origins,
            save_file=rdir + 'multi_plot_{}_{}_{}_{}_{}_{}.pdf'.format(
                *scenario),
            title='{}Myr, {}pc, {}km/s, {} stars, {}, {}'.format(*scenario),
        )
        print("done")
    except Exception:
        # Best effort: the fit's output files may not have been written
        # yet. Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit still propagate.
        print("Not ready yet...")
)
# NOTE(review): the ")" above closes a statement that begins before this
# chunk of the file.

# Initialize the MPI-based pool used for parallelization.
# Any exception raised by MPIPool() (bare except) falls back to running
# serially with no pool.
using_mpi = True
try:
    pool = MPIPool()
    logging.info("Successfully initialised mpi pool")
except:
    #print("MPI doesn't seem to be installed... maybe install it?")
    logging.info("MPI doesn't seem to be installed... maybe install it?")
    using_mpi = False
    pool=None

star_pars = dt.loadXYZUVW(xyzuvw_file)

# Print numpy floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)

# --------------------------------------------------------------------------
# Get grid-based Z membership
# --------------------------------------------------------------------------

# Grid
# Per-dimension bounds of the stellar means, padded by 1 unit on each side.
xyzuvw = star_pars['xyzuvw']
dmin=np.min(xyzuvw, axis=0)-1
dmax=np.max(xyzuvw, axis=0)+1
# Evenly spaced grid points over columns 3, 4, 5 (presumably U, V, W given
# the 'xyzuvw' ordering): 10 points for U and V, 6 for W.
grid_u = np.linspace(dmin[3], dmax[3], 10) #[-10000, -60, -20, 20, 60, 10000]
grid_v = np.linspace(dmin[4], dmax[4], 10)
grid_w = np.linspace(dmin[5], dmax[5], 6)
# MPI is switched off by hand here; the master/worker block below only
# takes effect if using_mpi is set to True and `pool` is a real pool.
using_mpi = False
pool = None
if using_mpi:
    if not pool.is_master():
        print("One thread is going to sleep")
        # Wait for instructions from the master process.
        pool.wait()
        sys.exit(0)
print("Only one thread is master")

print("Master should be working in the directory:\n{}".format(rdir))

star_pars = dt.loadXYZUVW(xyzuvw_file)

# GET BACKGROUND LOG OVERLAP DENSITIES:
# Reuse the cached array when it exists and matches the current number of
# stars; otherwise recompute it from the Gaia sample.
logging.info("Acquiring background overlaps")
try:
    bg_ln_ols = np.load(bg_ln_ols_file)
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let a stale cache through silently.
    if len(bg_ln_ols) != len(star_pars['xyzuvw']):
        raise ValueError(
            "cached bg_ln_ols has {} entries but there are {} stars".format(
                len(bg_ln_ols), len(star_pars['xyzuvw'])))
    logging.info("Loaded bg_ln_ols from file")
except (IOError, ValueError) as err:
    logging.info("Could not reuse cached bg_ln_ols ({})".format(err))
    logging.info("Calculating background overlaps")
    logging.info(" -- this step employs scipy's kernel density estimator")
    logging.info(" -- so could take a few minutes...")
    bg_ln_ols = dt.getKernelDensities(gaia_xyzuvw_file, star_pars['xyzuvw'],
                                      amp_scale=0.001)
# Star data (converted XYZUVW) that the incremental fit works on.
star_pars = dt.loadXYZUVW(xyzuvw_conv_savefile)

# Maximum number of components to try, and the number to start with.
MAX_COMP, ncomps = 6, 1

# "Previous" fit results start out empty: no groups / medians / memberships
# yet, log-likelihood and log-posterior at -inf, BIC at +inf.
prev_groups = prev_meds = prev_z = None
prev_lnpost = prev_lnlike = -np.inf
prev_BIC = np.inf
""" from __future__ import print_function, division import matplotlib as mpl mpl.use('Agg') import matplotlib.pyplot as plt import numpy as np import sys sys.path.insert(0, '..') import chronostar.retired2.datatool as dt MARGIN = 2.0 assoc_name = "bpmg_cand_w_gaia_dr2_astrometry_comb_binars" xyzuvw_file = "../data/{}_xyzuvw.fits".format(assoc_name) xyzuvw_dict = dt.loadXYZUVW(xyzuvw_file) # construct box kin_max = np.max(xyzuvw_dict['xyzuvw'], axis=0) kin_min = np.min(xyzuvw_dict['xyzuvw'], axis=0) span = kin_max - kin_min upper_boundary = kin_max + MARGIN*span lower_boundary = kin_min - MARGIN*span # get gaia stars within box gaia_xyzuvw_file = "../data/gaia_dr2_mean_xyzuvw.npy" gaia_xyzuvw = np.load(gaia_xyzuvw_file) mask = np.where( np.all( (gaia_xyzuvw < upper_boundary) & (gaia_xyzuvw > lower_boundary),
def plotMultiPane(dim_pairs, star_pars, groups, origins=None,
                  save_file='dummy.pdf', title=None):
    """
    Flexible function that plots many 2D slices through data and fits

    Takes as input a list of dimension pairs, stellar data and fitted
    groups, and will plot each dimension pair in a different pane (via
    plotPane). Panes are laid out in int(sqrt(npanes)) rows, with enough
    columns to hold them all, each pane 5x5 inches.

    TODO: Maybe add functionality to control the data plotted in each pane

    Parameters
    ----------
    dim_pairs: a non-empty list of dimension pairs e.g.
        [(0,1), (3,4), (0,3), (1,4), (2,5)]
        ['xz', 'uv', 'zw']
        ['XY', 'UV', 'XU', 'YV', 'ZW']
    star_pars: either dictionary of stellar data with keys 'xyzuvw' and
        'xyzuvw_cov' or string filename to saved data
    groups: either a single synthesiser.Group object,
        a list or array of synthesiser.Group objects,
        or string filename to data saved as '.npy' file
    origins: optional single origin object or iterable of them, forwarded
        to plotPane (as a list)
    save_file: string, name (and path) of saved plot figure (PDF); falsy to
        skip saving
    title: optional string used as the figure's suptitle

    Returns
    -------
    (nothing)

    Raises
    ------
    ValueError: if dim_pairs is empty
    """
    # Tidying up inputs
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(groups, str):
        # Group files hold pickled objects; numpy >= 1.16.3 refuses to load
        # them without allow_pickle. Only load trusted result files.
        groups = np.load(groups, allow_pickle=True)
        # handle case where groups is a single stored object
        if len(groups.shape) == 0:
            groups = groups.item()

    # ensure groups is iterable
    try:
        len(groups)
    except TypeError:
        # groups is a single group instance
        groups = [groups]

    # `if origins:` is ambiguous for multi-element numpy arrays, so test
    # against None and hand plotPane a plain list.
    if origins is not None:
        try:
            len(origins)
        except TypeError:
            # origins is a single group instance
            origins = [origins]
        origins = list(origins)

    # setting up plot dimensions
    npanes = len(dim_pairs)
    if npanes == 0:
        raise ValueError("dim_pairs must contain at least one pair")
    rows = int(np.sqrt(npanes))  # plots are never taller than wide
    cols = (npanes + rows - 1) // rows  # get enough cols
    ax_h = 5
    ax_w = 5
    # squeeze=False always yields a 2D array of axes, even for a single
    # pane (otherwise plt.subplots returns a bare Axes with no .flatten()).
    f, axs = plt.subplots(rows, cols, squeeze=False)
    f.set_size_inches(ax_w * cols, ax_h * rows)
    flat_axs = axs.flatten()

    # drawing each axes
    for i, (dim1, dim2) in enumerate(dim_pairs):
        plotPane(dim1, dim2, flat_axs[i], groups=groups, origins=origins,
                 star_pars=star_pars, star_orbits=False,
                 group_then=True, group_now=True, group_orbit=True,
                 annotate=False)

    if title:
        f.suptitle(title)
    if save_file:
        f.savefig(save_file, format='pdf')
def plotPaneWithHists(dim1, dim2, fignum=None, groups=None, weights=None,
                      star_pars=None,
                      star_orbits=False,
                      group_then=False, group_now=False, group_orbit=False,
                      annotate=False, bg_hists=None, membership=None,
                      true_memb=None, savefile='', with_bg=False,
                      range_1=None, range_2=None, residual=False,
                      markers=None, group_bg=False, isotropic=False,
                      color_labels=None, marker_labels=None,
                      marker_order=None, ordering=None, no_bg_covs=False):
    """
    Plot a 2D projection of data and fit along with flanking 1D projections.

    Uses global constants COLORS and HATCHES to inform consistent colour
    scheme.
    Can use this to plot different panes of one whole figure

    Layout is a 4x4 GridSpec: the central pane (drawn by plotPane) spans the
    lower-left 3x3 cells, dim1's 1D projection (plot1DProjection) sits above
    it, dim2's to its right, and the top-right cell is a blank axis reserved
    for a legend.

    TODO: Incorporate Z
    TODO: incoporate background histogram

    Parameters
    ----------
    dim1: x-axis, can either be integer 0-5 (inclusive) or a letter form
          'xyzuvw' (either case)
    dim2: y-axis, same conditions as dim1
    fignum: figure number in which to create the plot
    groups: a list of (or just one) synthesiser.Group objects, corresponding
            to the fit of the origin(s), or a '.npy' filename
    weights: per-group weights for the 1D projections. If None (and there
             are groups): the number of stars for a single group without
             background, else membership.sum(axis=0) if membership is
             given, else equal weights.
    star_pars: dict object with keys 'xyzuvw' ([nstars,6] array of current
               star means) and 'xyzuvw_cov' ([nstars,6,6] array of current
               star covariance matrices), or a filename
    star_orbits: (bool) plot the calculated stellar traceback orbits of
                 central estimate of measurements
    group_then: (bool) plot the group's origin
    group_now: (bool) plot the group's current day distribution
    group_orbit: (bool) plot the trajectory of the group's mean
    annotate: (bool) add text describing the figure's contents
    bg_hists: background histograms (or filename) for plot1DProjection
    membership: [nstars, ncomps] array (or filename) of fitted memberships
    with_bg: (bool) treat the final column of Z as background memberships
             and color accordingly
    color_labels, marker_labels, ordering, markers, group_bg, isotropic,
    range_1, range_2, true_memb, no_bg_covs: forwarded to plotPane
    residual: forwarded to plot1DProjection
    marker_order: currently unused
    savefile: if given, save the resulting figure to this path

    Returns
    -------
    (xlim, ylim): axis limits of the central pane, as returned by plotPane
    """
    labels = 'XYZUVW'
    axes_units = 3 * ['pc'] + 3 * ['km/s']

    # None defaults avoid shared mutable default arguments
    if groups is None:
        groups = []
    if color_labels is None:
        color_labels = []
    if marker_labels is None:
        marker_labels = []

    # Resolve every filename input *before* it is inspected; in particular
    # `groups` must be loaded before default weights use its length.
    if isinstance(membership, str):
        membership = np.load(membership)
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(groups, str):
        # Group files hold pickled objects; numpy >= 1.16.3 refuses to load
        # them without allow_pickle. Only load trusted result files.
        groups = np.load(groups, allow_pickle=True)
        if len(groups.shape) == 0:
            # A single saved object loads as a 0-d array; np.array(item)
            # would still be 0-d, so wrap it in a list first.
            groups = np.array([groups.item()])
    if isinstance(bg_hists, str):
        bg_hists = np.load(bg_hists)

    # TODO: clarify what exactly you're trying to do here
    if weights is None and len(groups) > 0:
        if len(groups) == 1 and not with_bg:
            weights = np.array([len(star_pars['xyzuvw'])])
        elif membership is not None:
            weights = membership.sum(axis=0)
        else:
            weights = np.ones(len(groups)) / len(groups)

    # Accept any integer type (e.g. numpy ints), convert letters to indices
    if isinstance(dim1, str):
        dim1 = labels.index(dim1.upper())
    if isinstance(dim2, str):
        dim2 = labels.index(dim2.upper())

    tick_params = {'direction': 'in', 'top': True, 'right': True}

    # Set up plot. Everything is drawn into this sized figure; a second,
    # bare plt.figure() here would discard it and leak a new empty figure
    # on every call.
    fig_width = 5  # inch
    fig_height = 5  # inch
    plt.figure(fignum, figsize=(fig_width, fig_height))
    plt.clf()
    gs = gridspec.GridSpec(4, 4)

    # Plot central pane
    axcen = plt.subplot(gs[1:, :-1])
    xlim, ylim = plotPane(
        dim1, dim2, ax=axcen, groups=groups, star_pars=star_pars,
        star_orbits=star_orbits, group_then=group_then,
        group_now=group_now, group_orbit=group_orbit, annotate=annotate,
        membership=membership, true_memb=true_memb, with_bg=with_bg,
        markers=markers, group_bg=group_bg, isotropic=isotropic,
        range_1=range_1, range_2=range_2, marker_labels=marker_labels,
        color_labels=color_labels, ordering=ordering, no_bg_covs=no_bg_covs)
    plt.tick_params(**tick_params)

    # Plot flanking 1D projections, sharing the central pane's limits
    axtop = plt.subplot(gs[0, :-1])
    axtop.set_xlim(xlim)
    axtop.set_xticklabels([])
    plot1DProjection(dim1, star_pars, groups, weights, ax=axtop,
                     bg_hists=bg_hists, with_bg=with_bg,
                     membership=membership,
                     residual=residual, x_range=xlim)
    axtop.set_ylabel('Stars per {}'.format(axes_units[dim1]))
    plt.tick_params(**tick_params)

    axright = plt.subplot(gs[1:, -1])
    axright.set_ylim(ylim)
    axright.set_yticklabels([])
    plot1DProjection(dim2, star_pars, groups, weights, ax=axright,
                     bg_hists=bg_hists, horizontal=True, with_bg=with_bg,
                     membership=membership, residual=residual,
                     x_range=ylim)
    axright.set_xlabel('Stars per {}'.format(axes_units[dim2]))
    plt.tick_params(**tick_params)

    # Top-right cell: blank axis reserved for a legend. Booleans are used
    # because matplotlib deprecated the 'on'/'off' strings for these
    # tick_params arguments.
    axleg = plt.subplot(gs[0, -1])
    for spine in axleg.spines.values():
        spine.set_visible(False)
    axleg.tick_params(labelbottom=False, labelleft=False,
                      bottom=False, left=False)

    if savefile:
        plt.savefig(savefile)

    return xlim, ylim
def plotPane(dim1=0, dim2=1, ax=None, groups=(), star_pars=None,
             origin_star_pars=None,
             star_orbits=False, origins=None,
             group_then=False, group_now=False, group_orbit=False,
             annotate=False, membership=None, true_memb=None, savefile='',
             with_bg=False, markers=None, group_bg=False,
             marker_labels=None, color_labels=None,
             marker_style=None,
             marker_legend=None, color_legend=None,
             star_pars_label=None, origin_star_pars_label=None,
             range_1=None, range_2=None, isotropic=False,
             ordering=None, no_bg_covs=False):
    """
    Plots a single pane capturing kinematic info in any desired 2D plane

    Uses global constants COLORS and HATCHES to inform consistent colour
    scheme (along with MARKERS, MARK_SIZE, BG_MARK_SIZE, PT_ALPHA,
    COV_ALPHA and FONTSIZE).
    Can use this to plot different panes of one whole figure

    Parameters
    ----------
    dim1: x-axis, can either be integer 0-5 (inclusive) or a letter form
          'xyzuvw' (either case)
    dim2: y-axis, same conditions as dim1
    ax: the axes object on which to plot (defaults to pyplot's current axes)
    groups: a list of (or just one) synthesiser.Group objects, corresponding
            to the fit of the origin(s), or a filename for dt.loadGroups.
            Each must provide get_mean(), get_covmatrix() and get_age().
    star_pars: dict object with keys 'xyzuvw' ([nstars,6] array of current
               star means) and 'xyzuvw_cov' ([nstars,6,6] array of current
               star covariance matrices), or a filename. If 'xyzuvw_cov' is
               missing it is filled with Nones (this mutates the dict).
    origin_star_pars: optional dict with key 'xyzuvw', drawn as small
               square markers in the same colours as star_pars
    star_orbits: (bool) plot the calculated stellar traceback orbits of
                 central estimate of measurements (every third star)
    origins: optional origin object or iterable of them, drawn in grey from
             their generateSphericalCovMatrix() and `mean`
    group_then: (bool) plot the group's origin
    group_now: (bool) plot the group's current day distribution
    group_orbit: (bool) plot the trajectory of the group's mean
    annotate: (bool) add text describing the figure's contents
    membership: optional [nstars, ncomps] array (or filename); each star is
                coloured by its highest-membership component
    true_memb: optional [nstars, ncomps] array that sets each star's marker
    savefile: if given, the current figure is cleared first and saved here
    with_bg: (bool) treat the last column in Z as members of background, and
             color accordingly
    group_bg: (bool) the final fitted component (rather than an extra
              column) represents the background
    no_bg_covs: (bool) ignore covariance matrices of stars fitted to
                background
    range_1, range_2: optional (min, max) limits for the x and y axes
    isotropic: (bool) equal aspect ratio; range_1 is ignored in that case

    Returns
    -------
    (xlim, ylim): the final axis limits of `ax`
    """
    labels = 'XYZUVW'
    units = 3 * ['pc'] + 3 * ['km/s']

    if savefile:
        plt.clf()

    # Tidying up inputs
    if ax is None:
        ax = plt.gca()
    if isinstance(dim1, str):
        dim1 = labels.index(dim1.upper())
    if isinstance(dim2, str):
        dim2 = labels.index(dim2.upper())
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(membership, str):
        membership = np.load(membership)
    if isinstance(groups, str):
        groups = dt.loadGroups(groups)
    if marker_style is None:
        marker_style = MARKERS[:]

    legend_pts = []
    legend_labels = []

    # ensure groups is iterable
    try:
        len(groups)
    except TypeError:
        groups = [groups]
    ngroups = len(groups)

    # ensure origins (if given) is iterable; `if origins:` would be
    # ambiguous for multi-element numpy arrays
    if origins is not None:
        try:
            len(origins)
        except TypeError:
            origins = [origins]

    if ordering is None:
        ordering = range(len(marker_style))

    # plot stellar data (positions with errors and optionally traceback
    # orbits back to some ill-defined age
    if star_pars:
        nstars = star_pars['xyzuvw'].shape[0]

        # apply default color and markers, to be overwritten if needed
        pt_colors = np.array(nstars * [COLORS[0]])
        if markers is None:
            markers = np.array(nstars * ['.'])

        # Incorporate fitted membership into colors of the pts
        if membership is not None:
            best_mship = np.argmax(membership[:, :ngroups + with_bg], axis=1)
            pt_colors = np.array(COLORS[:ngroups]
                                 + with_bg * ['xkcd:grey'])[best_mship]

        # Incorporate "True" membership into pt markers
        if true_memb is not None:
            markers = np.array(MARKERS)[np.argmax(true_memb, axis=1)]
            if with_bg:
                true_bg_mask = np.where(true_memb[:, -1] == 1.)
                markers[true_bg_mask] = '.'

        # Stars whose highest membership is the background: the extra final
        # column, or the last fitted component when group_bg is set. No
        # star is treated as background without a membership array.
        if with_bg and membership is not None:
            bg_mask = np.argmax(membership, axis=1) == ngroups - group_bg
        else:
            bg_mask = np.zeros(nstars, dtype=bool)

        all_mark_size = np.array(nstars * [MARK_SIZE])
        all_mark_size[bg_mask] = BG_MARK_SIZE

        mns = star_pars['xyzuvw']
        try:
            covs = np.copy(star_pars['xyzuvw_cov'])
        except KeyError:
            covs = len(mns) * [None]
            star_pars['xyzuvw_cov'] = covs
        else:
            if with_bg and no_bg_covs:
                print("Discarding background cov-mats")
                # Use a list so discarded entries really are None: assigning
                # None into a float array stores NaN, which slips past the
                # `is not None` check below.
                covs = [None if is_bg else cov
                        for cov, is_bg in zip(covs, bg_mask)]

        # plot traceback orbits for as long as oldest group (if known)
        # else, 30 Myr
        if star_orbits:
            try:
                tb_limit = max([g.get_age() for g in groups])
            except (AttributeError, TypeError, ValueError):
                tb_limit = 30

        pt = None
        for st_count, (star_mn, star_cov, marker, pt_color, m_size) in \
                enumerate(zip(mns, covs, markers, pt_colors, all_mark_size)):
            pt = ax.scatter(
                star_mn[dim1], star_mn[dim2], s=m_size,
                color=pt_color, marker=marker, alpha=PT_ALPHA,
                linewidth=0.0,
            )
            # plot uncertainties
            if star_cov is not None:
                plotCovEllipse(
                    star_cov[np.ix_([dim1, dim2], [dim1, dim2])],
                    star_mn[np.ix_([dim1, dim2])],
                    ax=ax, alpha=COV_ALPHA, linewidth='0.1',
                    color=pt_color,
                )
            # only every third star, to keep the plot readable
            if star_orbits and st_count % 3 == 0:
                plotOrbit(star_mn, dim1, dim2, ax, end_age=-tb_limit,
                          color='xkcd:grey')
        if star_pars_label:
            legend_pts.append(pt)
            legend_labels.append(star_pars_label)

        if origin_star_pars is not None:
            for star_mn, marker, pt_color, m_size in \
                    zip(origin_star_pars['xyzuvw'],
                        markers, pt_colors, all_mark_size):
                pt = ax.scatter(
                    star_mn[dim1], star_mn[dim2], s=0.5 * m_size,
                    color=pt_color, marker='s', alpha=PT_ALPHA,
                    linewidth=0.0,
                )
            if origin_star_pars_label:
                legend_pts.append(pt)
                legend_labels.append(origin_star_pars_label)

    # plot info for each group (fitted, or true synthetic origin)
    for i, group in enumerate(groups):
        cov_then = group.get_covmatrix()
        mean_then = group.get_mean()

        # plot group initial distribution
        if group_then:
            ax.plot(mean_then[dim1], mean_then[dim2], marker='+', alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_then[np.ix_([dim1, dim2])],
                           with_line=True,
                           ax=ax, alpha=0.3, ls='--',
                           color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_0, \mathbf{\Sigma}_0$',
                            (mean_then[dim1],
                             mean_then[dim2]),
                            color=COLORS[i])

        # The current-day mean is needed both for the present-day
        # distribution and as the start of the group's orbit, so compute it
        # whenever either is requested (group_orbit alone used to reference
        # an undefined / stale mean_now).
        if group_now or group_orbit:
            mean_now = torb.trace_cartesian_orbit(mean_then, group.get_age(),
                                                  single_age=True)

        # plot group current day distribution (should match well with stars)
        if group_now:
            cov_now = tf.transform_covmatrix(cov_then,
                                             torb.trace_cartesian_orbit,
                                             mean_then,
                                             args=[group.get_age()])
            ax.plot(mean_now[dim1], mean_now[dim2], marker='+', alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(
                cov_now[np.ix_([dim1, dim2], [dim1, dim2])],
                mean_now[np.ix_([dim1, dim2])],
                ax=ax, alpha=0.4, ls='-.',
                ec=COLORS[i], fill=False, hatch=HATCHES[i],
                color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_c, \mathbf{\Sigma}_c$',
                            (mean_now[dim1], mean_now[dim2]),
                            color=COLORS[i])

        # plot orbit of mean of group (get_age() for consistency with the
        # calls above)
        if group_orbit:
            plotOrbit(mean_now, dim1, dim2, ax, -group.get_age(), group_ix=i,
                      with_arrow=True, annotate=annotate)

    if origins is not None:
        for origin in origins:
            cov_then = origin.generateSphericalCovMatrix()
            mean_then = origin.mean
            # plot origin initial distribution
            ax.plot(mean_then[dim1], mean_then[dim2], marker='+',
                    color='xkcd:grey')
            plotCovEllipse(cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_then[np.ix_([dim1, dim2])],
                           with_line=True,
                           ax=ax, alpha=0.1, ls='--',
                           color='xkcd:grey')

    ax.set_xlabel("{} [{}]".format(labels[dim1], units[dim1]))
    ax.set_ylabel("{} [{}]".format(labels[dim2], units[dim2]))

    # update fontsize
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label]
                 + ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(FONTSIZE)
    if ax.get_legend() is not None:
        for item in ax.get_legend().get_texts():
            item.set_fontsize(FONTSIZE)

    if range_2:
        ax.set_ylim(range_2)

    if isotropic:
        print("Setting isotropic for dims {} and {}".format(dim1, dim2))
        # Apply to *this* axes; plt.gca() may be a different one when `ax`
        # was passed in explicitly.
        ax.set_aspect('equal', adjustable='datalim')

        # manually calculate what the new xaxis must be...
        figW, figH = ax.get_figure().get_size_inches()
        xmid = (ax.get_xlim()[1] + ax.get_xlim()[0]) * 0.5
        yspan = ax.get_ylim()[1] - ax.get_ylim()[0]
        xspan = figW * yspan / figH

        # check if this increases span
        if xspan > ax.get_xlim()[1] - ax.get_xlim()[0]:
            ax.set_xlim(xmid - 0.5 * xspan, xmid + 0.5 * xspan)
        # if not, need to increase yspan
        else:
            ymid = (ax.get_ylim()[1] + ax.get_ylim()[0]) * 0.5
            xspan = ax.get_xlim()[1] - ax.get_xlim()[0]
            yspan = figH * xspan / figW
            ax.set_ylim(ymid - 0.5 * yspan, ymid + 0.5 * yspan)
    elif range_1:
        ax.set_xlim(range_1)

    # Legend entries are added via dummy artists far outside the view,
    # restoring the limits afterwards.
    if color_labels is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for i, color_label in enumerate(color_labels):
            ax.plot(1e10, 1e10, color=COLORS[i], label=color_label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if marker_labels is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for i, marker_label in enumerate(marker_labels):
            ax.scatter(1e10, 1e10, c='black',
                       marker=np.array(marker_style)[ordering][i],
                       label=marker_label)
        if with_bg:
            ax.scatter(1e10, 1e10, c='xkcd:grey',
                       marker='.', label='Background')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if color_legend is not None and marker_legend is not None:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        for label in color_legend.keys():
            ax.scatter(1e10, 1e10, c=color_legend[label],
                       marker=marker_legend[label], label=label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if savefile:
        plt.savefig(savefile)

    return ax.get_xlim(), ax.get_ylim()
""" from __future__ import print_function, division import matplotlib as mpl mpl.use('Agg') import matplotlib.pyplot as plt import numpy as np import sys sys.path.insert(0, '..') import chronostar.retired2.datatool as dt MARGIN = 2.0 assoc_name = "bpmg_cand_w_gaia_dr2_astrometry_comb_binars" xyzuvw_file = "../data/{}_xyzuvw.fits".format(assoc_name) xyzuvw_dict = dt.loadXYZUVW(xyzuvw_file) # construct box kin_max = np.max(xyzuvw_dict['xyzuvw'], axis=0) kin_min = np.min(xyzuvw_dict['xyzuvw'], axis=0) span = kin_max - kin_min upper_boundary = kin_max + MARGIN * span lower_boundary = kin_min - MARGIN * span # get gaia stars within box gaia_xyzuvw_file = "../data/gaia_dr2_mean_xyzuvw.npy" gaia_xyzuvw = np.load(gaia_xyzuvw_file) mask = np.where( np.all((gaia_xyzuvw < upper_boundary) & (gaia_xyzuvw > lower_boundary), axis=1)) print(mask[0].shape)
# ---- Input locations for the fed_stars fit ----
synth_fit = 'fed_stars'
rdir = '../../results/archive/fed_fits/20/gaia/'

# Files inside rdir
origins_file = rdir + 'origins.npy'
origin_comp_file = rdir + 'origin_ellip_comp.npy'
chain_file = rdir + 'final_chain.npy'
lnprob_file = rdir + 'final_lnprob.npy'
star_pars_file = rdir + 'xyzuvw_now.fits'
# Files one directory above rdir
init_xyzuvw_file = rdir + '../xyzuvw_init_offset.npy'
perf_xyzuvw_now = rdir + '../perf_xyzuvw.npy'

# ---- Load everything ----
best_comp = SphereComponent.get_best_from_chain(chain_file, lnprob_file)
origin_comp = EllipComponent.load_raw_components(origin_comp_file)[0]
init_xyzuvw = np.load(init_xyzuvw_file)
star_pars = dt.loadXYZUVW(star_pars_file)
perf_mean_now = np.load(perf_xyzuvw_now)
original_origin = dt.loadGroups(origins_file)[0]

# Shorthands for the current-day star means and covariance matrices
mns = star_pars['xyzuvw']
covs = star_pars['xyzuvw_cov']

# Plot ranges from calcRanges over current-day and initial positions combined
stacked_xyzuvw = {'xyzuvw': np.vstack((mns, init_xyzuvw))}
fed_xranges, fed_yranges = calcRanges(stacked_xyzuvw, sep_axes=True)

labels = 'XYZUVW'
units = ['[pc]'] * 3 + ['[km/s]'] * 3
logging.info(" with error fraction {}".format(ERROR))
logging.info(" and background density {}".format(BG_DENS))

# Set a current-day location around which synth stars will end up
mean_now = np.array([50., -100., -0., -10., -20., -5.])

logging.info("Mean (now):\n{}".format(mean_now))
logging.info("Extra pars:\n{}".format(extra_pars))
logging.info("Offsets:\n{}".format(offsets))

# Probe for previously generated synthetic data. If all three loads succeed,
# UserWarning is raised; only IOError is caught below, so the UserWarning
# propagates out of this block. An IOError from any load sends execution
# into the generation branch instead.
try:
    #all_xyzuvw_now_perf = np.load(xyzuvw_perf_file)
    np.load(xyzuvw_perf_file)
    #origins = dt.loadGroups(groups_savefile)
    dt.loadGroups(groups_savefile)
    #star_pars = dt.loadXYZUVW(xyzuvw_conv_savefile)
    dt.loadXYZUVW(xyzuvw_conv_savefile)
    logging.info("Synth data exists! .....")
    print("Synth data exists")
    raise UserWarning
except IOError:
    # Accumulators start as empty (0, 6) arrays so per-group results can be
    # stacked onto them.
    all_xyzuvw_init = np.zeros((0, 6))
    all_xyzuvw_now_perf = np.zeros((0, 6))
    origins = []
    for i in range(ngroups):
        logging.info(" generating from group {}".format(i))
        # MANUALLY SEPARATE CURRENT DAY DISTROS IN DIMENSION X
        # (each group's current-day mean is shifted by its own offsets[i])
        mean_now_w_offset = mean_now.copy()
        # mean_now_w_offset[0] += i * 50
        mean_now_w_offset += offsets[i]

        mean_then = torb.trace_cartesian_orbit(mean_now_w_offset,
def plotPaneWithHists(dim1, dim2, fignum=None, groups=None, weights=None,
                      star_pars=None,
                      star_orbits=False,
                      group_then=False, group_now=False, group_orbit=False,
                      annotate=False, bg_hists=None, membership=None,
                      true_memb=None, savefile='', with_bg=False,
                      range_1=None, range_2=None, residual=False,
                      markers=None, group_bg=False, isotropic=False,
                      color_labels=None, marker_labels=None,
                      marker_order=None, ordering=None, no_bg_covs=False):
    """
    Plot a 2D projection of data and fit along with flanking 1D projections.

    Uses global constants COLORS and HATCHES to inform consistent colour
    scheme.
    Can use this to plot different panes of one whole figure

    Layout is a 4x4 GridSpec: the central pane (drawn by plotPane) spans the
    lower-left 3x3 cells, dim1's 1D projection (plot1DProjection) sits above
    it, dim2's to its right, and the top-right cell is a blank axis reserved
    for a legend.

    TODO: Incorporate Z
    TODO: incoporate background histogram

    Parameters
    ----------
    dim1: x-axis, can either be integer 0-5 (inclusive) or a letter form
          'xyzuvw' (either case)
    dim2: y-axis, same conditions as dim1
    fignum: figure number in which to create the plot
    groups: a list of (or just one) synthesiser.Group objects, corresponding
            to the fit of the origin(s), or a '.npy' filename
    weights: per-group weights for the 1D projections. If None (and there
             are groups): the number of stars for a single group without
             background, else membership.sum(axis=0) if membership is
             given, else equal weights.
    star_pars: dict object with keys 'xyzuvw' ([nstars,6] array of current
               star means) and 'xyzuvw_cov' ([nstars,6,6] array of current
               star covariance matrices), or a filename
    star_orbits: (bool) plot the calculated stellar traceback orbits of
                 central estimate of measurements
    group_then: (bool) plot the group's origin
    group_now: (bool) plot the group's current day distribution
    group_orbit: (bool) plot the trajectory of the group's mean
    annotate: (bool) add text describing the figure's contents
    bg_hists: background histograms (or filename) for plot1DProjection
    membership: [nstars, ncomps] array (or filename) of fitted memberships
    with_bg: (bool) treat the final column of Z as background memberships
             and color accordingly
    color_labels, marker_labels, ordering, markers, group_bg, isotropic,
    range_1, range_2, true_memb, no_bg_covs: forwarded to plotPane
    residual: forwarded to plot1DProjection
    marker_order: currently unused
    savefile: if given, save the resulting figure to this path

    Returns
    -------
    (xlim, ylim): axis limits of the central pane, as returned by plotPane
    """
    labels = 'XYZUVW'
    axes_units = 3 * ['pc'] + 3 * ['km/s']

    # None defaults avoid shared mutable default arguments
    if groups is None:
        groups = []
    if color_labels is None:
        color_labels = []
    if marker_labels is None:
        marker_labels = []

    # Resolve every filename input *before* it is inspected; in particular
    # `groups` must be loaded before default weights use its length.
    if isinstance(membership, str):
        membership = np.load(membership)
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(groups, str):
        # Group files hold pickled objects; numpy >= 1.16.3 refuses to load
        # them without allow_pickle. Only load trusted result files.
        groups = np.load(groups, allow_pickle=True)
        if len(groups.shape) == 0:
            # A single saved object loads as a 0-d array; np.array(item)
            # would still be 0-d, so wrap it in a list first.
            groups = np.array([groups.item()])
    if isinstance(bg_hists, str):
        bg_hists = np.load(bg_hists)

    # TODO: clarify what exactly you're trying to do here
    if weights is None and len(groups) > 0:
        if len(groups) == 1 and not with_bg:
            weights = np.array([len(star_pars['xyzuvw'])])
        elif membership is not None:
            weights = membership.sum(axis=0)
        else:
            weights = np.ones(len(groups)) / len(groups)

    # Accept any integer type (e.g. numpy ints), convert letters to indices
    if isinstance(dim1, str):
        dim1 = labels.index(dim1.upper())
    if isinstance(dim2, str):
        dim2 = labels.index(dim2.upper())

    tick_params = {'direction': 'in', 'top': True, 'right': True}

    # Set up plot. Everything is drawn into this sized figure; a second,
    # bare plt.figure() here would discard it and leak a new empty figure
    # on every call.
    fig_width = 5  # inch
    fig_height = 5  # inch
    plt.figure(fignum, figsize=(fig_width, fig_height))
    plt.clf()
    gs = gridspec.GridSpec(4, 4)

    # Plot central pane
    axcen = plt.subplot(gs[1:, :-1])
    xlim, ylim = plotPane(
        dim1, dim2, ax=axcen, groups=groups, star_pars=star_pars,
        star_orbits=star_orbits, group_then=group_then,
        group_now=group_now, group_orbit=group_orbit, annotate=annotate,
        membership=membership, true_memb=true_memb, with_bg=with_bg,
        markers=markers, group_bg=group_bg, isotropic=isotropic,
        range_1=range_1, range_2=range_2, marker_labels=marker_labels,
        color_labels=color_labels, ordering=ordering, no_bg_covs=no_bg_covs)
    plt.tick_params(**tick_params)

    # Plot flanking 1D projections, sharing the central pane's limits
    axtop = plt.subplot(gs[0, :-1])
    axtop.set_xlim(xlim)
    axtop.set_xticklabels([])
    plot1DProjection(dim1, star_pars, groups, weights, ax=axtop,
                     bg_hists=bg_hists, with_bg=with_bg,
                     membership=membership,
                     residual=residual, x_range=xlim)
    axtop.set_ylabel('Stars per {}'.format(axes_units[dim1]))
    plt.tick_params(**tick_params)

    axright = plt.subplot(gs[1:, -1])
    axright.set_ylim(ylim)
    axright.set_yticklabels([])
    plot1DProjection(dim2, star_pars, groups, weights, ax=axright,
                     bg_hists=bg_hists, horizontal=True, with_bg=with_bg,
                     membership=membership, residual=residual,
                     x_range=ylim)
    axright.set_xlabel('Stars per {}'.format(axes_units[dim2]))
    plt.tick_params(**tick_params)

    # Top-right cell: blank axis reserved for a legend. Booleans are used
    # because matplotlib deprecated the 'on'/'off' strings for these
    # tick_params arguments.
    axleg = plt.subplot(gs[0, -1])
    for spine in axleg.spines.values():
        spine.set_visible(False)
    axleg.tick_params(labelbottom=False, labelleft=False,
                      bottom=False, left=False)

    if savefile:
        plt.savefig(savefile)

    return xlim, ylim
if os.path.isfile(rdir + 'bg_hists.npy'):
    bg_hists = np.load(rdir + 'bg_hists.npy')
else:
    bg_hists = None

is_inc_fit = os.path.isdir(rdir + '1/')
is_synth_fit = os.path.isdir(rdir + 'synth_data/')

# First, if stars are synthetic, plot true groups
true_memb = None
if is_synth_fit:
    # The origins file holds pickled group objects; numpy >= 1.16.3 needs
    # allow_pickle to load it. Only load trusted result files.
    origins = np.load(rdir + 'synth_data/origins.npy', allow_pickle=True)
    # A single saved origin loads as a 0-d array. Wrap it *before* anything
    # iterates over or takes len() of `origins` (getZfromOrigins does both),
    # and wrap in a list: np.array(item) on its own would still be 0-d.
    if len(origins.shape) == 0:
        origins = np.array([origins.item()])

    true_memb = getZfromOrigins(origins, star_pars_file)
    # An extra membership column beyond the origins means background stars
    with_bg = len(origins) < true_memb.shape[1]
    assert true_memb.shape[0] == \
        dt.loadXYZUVW(star_pars_file)['xyzuvw'].shape[0]

    weights = np.array([origin.nstars for origin in origins])
    for dim1, dim2 in ('xy', 'uv', 'xu', 'yv', 'zw', 'xw'):
        plt.clf()
        fp.plotPaneWithHists(dim1, dim2, star_pars=star_pars_file,
                             groups=origins, weights=weights,
                             group_now=True, with_bg=with_bg,
                             no_bg_covs=with_bg,
                             )
        plt.savefig(rdir + 'pre_plot_{}{}.pdf'.format(dim1, dim2))

# Now choose if handling incremental fit or plain fit
if not is_inc_fit:
    plotEveryIter(rdir, star_pars_file)
labels=axis_labels, # reverse=True, label_kwargs={'fontsize':'xx-large'}, max_n_ticks=4, ) print("Applying tick parameters") for ax in fig.axes: ax.tick_params(direction='in', labelsize='x-large', top=True, right=True) print("... saving") plt.savefig(plot_name) if PLOT_BPMG_REAL: for iteration in ['5B']: #, '6C']: star_pars_file = '../../data/beta_Pictoris_with_gaia_small_xyzuvw.fits' star_pars = dt.loadXYZUVW(star_pars_file) fit_name = 'bpmg_and_nearby' rdir = '../../results/em_fit/beta_Pictoris_wgs_inv2_{}_res/'.format(iteration) memb_file = rdir + 'final_membership.npy' groups_file = rdir + 'final_groups.npy' z = np.load(memb_file) groups = dt.loadGroups(groups_file) # Assign markers based on BANYAN membership gt_sp = dt.loadDictFromTable('../../data/banyan_with_gaia_near_bpmg_xyzuvw.fits') banyan_membership = len(star_pars['xyzuvw']) * ['N/A'] for i in range(len(star_pars['xyzuvw'])): master_table_ix = np.where(gt_sp['table']['source_id']==star_pars['gaia_ids'][i]) banyan_membership[i] = gt_sp['table']['Moving group'][master_table_ix[0][0]]
# MPI is switched off by hand here; the master/worker block below only
# takes effect if using_mpi is set to True and `pool` is a real pool.
using_mpi = False
pool = None
if using_mpi:
    if not pool.is_master():
        print("One thread is going to sleep")
        # Wait for instructions from the master process.
        pool.wait()
        sys.exit(0)
print("Only one thread is master")

print("Master should be working in the directory:\n{}".format(rdir))

star_pars = dt.loadXYZUVW(xyzuvw_file, assoc_name=ass_name)

# GET BACKGROUND LOG OVERLAP DENSITIES:
# Reuse the cached array when it exists and matches the current number of
# stars; otherwise recompute it from the Gaia sample.
logging.info("Acquiring background overlaps")
try:
    bg_ln_ols = np.load(bg_ln_ols_file)
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let a stale cache through silently.
    if len(bg_ln_ols) != len(star_pars['xyzuvw']):
        raise ValueError(
            "cached bg_ln_ols has {} entries but there are {} stars".format(
                len(bg_ln_ols), len(star_pars['xyzuvw'])))
    logging.info("Loaded bg_ln_ols from file")
except (IOError, ValueError) as err:
    logging.info("Could not reuse cached bg_ln_ols ({})".format(err))
    logging.info("Calculating background overlaps")
    logging.info(" -- this step employs scipy's kernel density estimator")
    logging.info(" -- so could take a few minutes...")
    bg_ln_ols = dt.getKernelDensities(gaia_xyzuvw_file, star_pars['xyzuvw'])
logging.info(" with error fraction {}".format(ERROR))
logging.info(" and background density {}".format(BG_DENS))

# Set a current-day location around which synth stars will end up
mean_now = np.array([50., -100., -0., -10., -20., -5.])

logging.info("Mean (now):\n{}".format(mean_now))
logging.info("Extra pars:\n{}".format(extra_pars))
logging.info("Offsets:\n{}".format(offsets))

# Probe for previously generated synthetic data. If all three loads succeed,
# UserWarning is raised; only IOError is caught below, so the UserWarning
# propagates out of this block. An IOError from any load sends execution
# into the generation branch instead.
try:
    #all_xyzuvw_now_perf = np.load(xyzuvw_perf_file)
    np.load(xyzuvw_perf_file)
    #origins = dt.loadGroups(groups_savefile)
    dt.loadGroups(groups_savefile)
    #star_pars = dt.loadXYZUVW(xyzuvw_conv_savefile)
    dt.loadXYZUVW(xyzuvw_conv_savefile)
    logging.info("Synth data exists! .....")
    print("Synth data exists")
    raise UserWarning
except IOError:
    # Accumulators start as empty (0, 6) arrays so per-group results can be
    # stacked onto them.
    all_xyzuvw_init = np.zeros((0,6))
    all_xyzuvw_now_perf = np.zeros((0,6))
    origins = []
    for i in range(ngroups):
        logging.info(" generating from group {}".format(i))
        # MANUALLY SEPARATE CURRENT DAY DISTROS IN DIMENSION X
        # (each group's current-day mean is shifted by its own offsets[i])
        mean_now_w_offset = mean_now.copy()
        # mean_now_w_offset[0] += i * 50
        mean_now_w_offset += offsets[i]

        # Trace the offset mean backwards by -extra_pars[i, -2] (presumably
        # the group's age -- TODO confirm the extra_pars column layout).
        mean_then = torb.trace_cartesian_orbit(mean_now_w_offset,
                                               -extra_pars[i, -2],
def plotPane(dim1=0, dim2=1, ax=None, groups=(), star_pars=None,
             origin_star_pars=None, star_orbits=False, origins=None,
             group_then=False, group_now=False, group_orbit=False,
             annotate=False, membership=None, true_memb=None, savefile='',
             with_bg=False, markers=None, group_bg=False, marker_labels=None,
             color_labels=None, marker_style=None, marker_legend=None,
             color_legend=None, star_pars_label=None,
             origin_star_pars_label=None, range_1=None, range_2=None,
             isotropic=False, ordering=None, no_bg_covs=False):
    """
    Plots a single pane capturing kinematic info in any desired 2D plane

    Uses global constants COLORS and HATCHES to inform consistent colour
    scheme. Can use this to plot different panes of one whole figure.

    Parameters
    ----------
    dim1: x-axis, either an integer 0-5 (inclusive) or a letter from
        'xyzuvw' (either case)
    dim2: y-axis, same conditions as dim1
    ax: the axes object on which to plot (defaults to pyplot's current axes)
    groups: a list of (or just one) group objects (e.g. synthesiser.Group),
        corresponding to the fit of the origin(s). Each must provide
        `.mean`, `.age` and `generateSphericalCovMatrix()`. May also be a
        filename, which is loaded with dt.loadGroups.
    star_pars: dict object with keys 'xyzuvw' ([nstars,6] array of current
        star means) and optionally 'xyzuvw_cov' ([nstars,6,6] array of
        current star covariance matrices), or a filename loaded with
        dt.loadXYZUVW. If 'xyzuvw_cov' is missing, the caller's dict gets
        'xyzuvw_cov' set to a list of None.
    origin_star_pars: same format as star_pars (only 'xyzuvw' is used), or a
        filename. Plotted as small squares coloured like the matching star
        in star_pars; only drawn when star_pars is given.
    star_orbits: (bool) plot the calculated traceback orbits of every third
        star, back to the oldest group's age (30 Myr if that is unknown)
    origins: iterable of group-like objects (or a filename) whose initial
        distributions are drawn in grey
    group_then: (bool) plot the group's origin
    group_now: (bool) plot the group's current day distribution
    group_orbit: (bool) plot the trajectory of the group's mean
    annotate: (bool) add text describing the figure's contents
    membership: [nstars, ncomps] membership array (or a .npy filename); each
        star is coloured by its highest-membership component
    true_memb: [nstars, ncomps] "true" membership array; each star's marker
        is taken from MARKERS by its highest true membership
    savefile: if non-empty, the current figure is cleared first and saved to
        this path at the end
    with_bg: (bool) treat the last column in Z as members of background,
        and color accordingly
    markers: per-star marker array; overridden when true_memb is given
    group_bg: (bool) the background is fitted by the final component
        (membership column ngroups-1) instead of an extra column (ngroups)
    marker_labels: legend labels drawn with markers from marker_style
        (reordered by `ordering`); a 'Background' entry is added if with_bg
    color_labels: legend labels drawn as lines in COLORS order
    marker_style: marker sequence used for marker_labels (default: MARKERS)
    marker_legend, color_legend: dicts mapping legend label -> marker /
        colour; the legend is drawn only when both are given
    star_pars_label, origin_star_pars_label: accepted for backwards
        compatibility; currently not rendered
    range_1, range_2: x / y limits to apply (range_1 is ignored when
        isotropic is set)
    isotropic: (bool) equal aspect ratio, widening whichever axis is needed
    ordering: index sequence into marker_style used for marker_labels
        (default: range(len(marker_style)))
    no_bg_covs: (bool) ignore covariance matrices of stars fitted to
        background

    Returns
    -------
    (xlim, ylim): the final axis limits of `ax`
    """
    labels = 'XYZUVW'
    units = 3 * ['pc'] + 3 * ['km/s']

    if savefile:
        plt.clf()

    # Tidying up inputs
    if ax is None:
        ax = plt.gca()
    # accept letters in either case, or any integer-like (incl. numpy ints)
    dim1 = labels.index(dim1.upper()) if isinstance(dim1, str) else int(dim1)
    dim2 = labels.index(dim2.upper()) if isinstance(dim2, str) else int(dim2)
    if isinstance(star_pars, str):
        star_pars = dt.loadXYZUVW(star_pars)
    if isinstance(origin_star_pars, str):
        origin_star_pars = dt.loadXYZUVW(origin_star_pars)
    if isinstance(membership, str):
        membership = np.load(membership)
    if isinstance(groups, str):
        groups = dt.loadGroups(groups)
    if isinstance(origins, str):
        origins = dt.loadGroups(origins)
    if marker_style is None:
        marker_style = MARKERS[:]

    # ensure groups is iterable (a single group object has no len())
    try:
        len(groups)
    except TypeError:
        groups = [groups]
    ngroups = len(groups)

    if ordering is None:
        ordering = range(len(marker_style))

    # plot stellar data (positions with errors and optionally traceback
    # orbits)
    if star_pars:
        nstars = star_pars['xyzuvw'].shape[0]

        # apply default color and markers, to be overwritten if needed
        pt_colors = np.array(nstars * [COLORS[0]])
        if markers is None:
            markers = np.array(nstars * ['.'])

        # Incorporate fitted membership into colors of the pts
        if membership is not None:
            best_mship = np.argmax(membership[:, :ngroups + with_bg], axis=1)
            pt_colors = np.array(
                COLORS[:ngroups] + with_bg * ['xkcd:grey'])[best_mship]

        # Incorporate "True" membership into pt markers
        if true_memb is not None:
            markers = np.array(MARKERS)[np.argmax(true_memb, axis=1)]
            if with_bg:
                true_bg_mask = np.where(true_memb[:, -1] == 1.)
                markers[true_bg_mask] = '.'

        # Stars whose best-fit component is the background. group_bg handles
        # the case where background is fitted by the final component. With
        # no membership there is nothing to flag.
        if with_bg and membership is not None:
            fitted_bg_mask = (np.argmax(membership, axis=1)
                              == ngroups - group_bg)
        else:
            fitted_bg_mask = np.zeros(nstars, dtype=bool)

        all_mark_size = np.array(nstars * [MARK_SIZE])
        all_mark_size[fitted_bg_mask] = BG_MARK_SIZE

        mns = star_pars['xyzuvw']
        try:
            # Held as a list of per-star matrices so that individual entries
            # can be replaced by None (assigning None into a float ndarray
            # would silently store NaN and the ellipse would still be drawn).
            covs = list(np.copy(star_pars['xyzuvw_cov']))
        except KeyError:
            covs = len(mns) * [None]
            star_pars['xyzuvw_cov'] = covs
        else:
            # replace background cov matrices with None so as to avoid
            # plotting
            if with_bg and no_bg_covs:
                print("Discarding background cov-mats")
                for star_ix in np.flatnonzero(fitted_bg_mask):
                    covs[star_ix] = None

        if star_orbits:
            # trace back for as long as the oldest group (if known),
            # else 30 Myr
            try:
                tb_limit = max([g.age for g in groups])
            except (ValueError, AttributeError, TypeError):
                tb_limit = 30

        for star_ix, (star_mn, star_cov, marker, pt_color, m_size) in \
                enumerate(zip(mns, covs, markers, pt_colors, all_mark_size)):
            ax.scatter(star_mn[dim1], star_mn[dim2],
                       s=m_size,
                       color=pt_color,
                       marker=marker,
                       alpha=PT_ALPHA,
                       linewidth=0.0,
                       )
            # plot uncertainties
            if star_cov is not None:
                plotCovEllipse(star_cov[np.ix_([dim1, dim2], [dim1, dim2])],
                               star_mn[np.ix_([dim1, dim2])],
                               ax=ax, alpha=COV_ALPHA, linewidth='0.1',
                               color=pt_color,
                               )
            # only every third star gets an orbit (presumably to limit
            # clutter)
            if star_orbits and star_ix % 3 == 0:
                plotOrbit(star_mn, dim1, dim2, ax,
                          end_age=-tb_limit,
                          color='xkcd:grey')

        if origin_star_pars is not None:
            # `markers` stays in the zip so truncation matches star_pars
            for star_mn, _, pt_color, m_size in zip(
                    origin_star_pars['xyzuvw'], markers, pt_colors,
                    all_mark_size):
                ax.scatter(star_mn[dim1], star_mn[dim2],
                           s=0.5 * m_size,
                           color=pt_color,
                           marker='s',
                           alpha=PT_ALPHA,
                           linewidth=0.0,
                           )

    # plot info for each group (fitted, or true synthetic origin)
    for i, group in enumerate(groups):
        cov_then = group.generateSphericalCovMatrix()
        mean_then = group.mean

        # plot group initial distribution
        if group_then:
            ax.plot(mean_then[dim1], mean_then[dim2], marker='+', alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_then[np.ix_([dim1, dim2])],
                           with_line=True,
                           ax=ax, alpha=0.3, ls='--',
                           color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_0, \mathbf{\Sigma}_0$',
                            (mean_then[dim1], mean_then[dim2]),
                            color=COLORS[i])

        # the current-day mean is needed both for the current distribution
        # and as the start point of the group's orbit
        if group_now or group_orbit:
            mean_now = torb.trace_cartesian_orbit(mean_then, group.age,
                                                  single_age=True)

        # plot group current day distribution (should match well with stars)
        if group_now:
            cov_now = tf.transform_covmatrix(cov_then,
                                             torb.trace_cartesian_orbit,
                                             mean_then, args=[group.age])
            ax.plot(mean_now[dim1], mean_now[dim2], marker='+', alpha=0.3,
                    color=COLORS[i])
            plotCovEllipse(cov_now[np.ix_([dim1, dim2], [dim1, dim2])],
                           mean_now[np.ix_([dim1, dim2])],
                           ax=ax, alpha=0.4, ls='-.',
                           ec=COLORS[i], fill=False, hatch=HATCHES[i],
                           color=COLORS[i])
            if annotate:
                ax.annotate(r'$\mathbf{\mu}_c, \mathbf{\Sigma}_c$',
                            (mean_now[dim1], mean_now[dim2]),
                            color=COLORS[i])

        # plot orbit of mean of group
        if group_orbit:
            plotOrbit(mean_now, dim1, dim2, ax, -group.age, group_ix=i,
                      with_arrow=True, annotate=annotate)

    # `is not None` rather than truthiness: origins may be an ndarray
    if origins is not None:
        for origin in origins:
            cov_then = origin.generateSphericalCovMatrix()
            mean_then = origin.mean
            # plot origin initial distribution
            ax.plot(mean_then[dim1], mean_then[dim2], marker='+',
                    color='xkcd:grey')
            plotCovEllipse(
                cov_then[np.ix_([dim1, dim2], [dim1, dim2])],
                mean_then[np.ix_([dim1, dim2])],
                with_line=True,
                ax=ax, alpha=0.1, ls='--', color='xkcd:grey')

    ax.set_xlabel("{} [{}]".format(labels[dim1], units[dim1]))
    ax.set_ylabel("{} [{}]".format(labels[dim2], units[dim2]))

    # update fontsize
    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                 ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(FONTSIZE)
    if ax.get_legend() is not None:
        for item in ax.get_legend().get_texts():
            item.set_fontsize(FONTSIZE)

    if range_2:
        ax.set_ylim(range_2)

    if isotropic:
        print("Setting isotropic for dims {} and {}".format(dim1, dim2))
        # act on `ax` itself (not pyplot's current axes) so an explicitly
        # passed axes object is the one made isotropic
        ax.set_aspect('equal', adjustable='datalim')

        # manually calculate what the new xaxis must be...
        figW, figH = ax.get_figure().get_size_inches()
        xmid = (ax.get_xlim()[1] + ax.get_xlim()[0]) * 0.5
        yspan = ax.get_ylim()[1] - ax.get_ylim()[0]
        xspan = figW * yspan / figH

        # check if this increases span
        if xspan > ax.get_xlim()[1] - ax.get_xlim()[0]:
            ax.set_xlim(xmid - 0.5 * xspan, xmid + 0.5 * xspan)
        # if not, need to increase yspan
        else:
            ymid = (ax.get_ylim()[1] + ax.get_ylim()[0]) * 0.5
            xspan = ax.get_xlim()[1] - ax.get_xlim()[0]
            yspan = figH * xspan / figW
            ax.set_ylim(ymid - 0.5 * yspan, ymid + 0.5 * yspan)
    elif range_1:
        ax.set_xlim(range_1)

    # Legend entries are drawn as proxy artists far outside the data, and
    # the axis limits are restored so the proxies never affect the view.
    if color_labels is not None:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        for i, color_label in enumerate(color_labels):
            ax.plot(1e10, 1e10, color=COLORS[i], label=color_label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if marker_labels is not None:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        ordered_markers = np.array(marker_style)[ordering]
        for i, marker_label in enumerate(marker_labels):
            ax.scatter(1e10, 1e10, c='black',
                       marker=ordered_markers[i],
                       label=marker_label)
        if with_bg:
            ax.scatter(1e10, 1e10, c='xkcd:grey',
                       marker='.', label='Background')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if color_legend is not None and marker_legend is not None:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        for label, color in color_legend.items():
            ax.scatter(1e10, 1e10, c=color,
                       marker=marker_legend[label], label=label)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='best')

    if savefile:
        plt.savefig(savefile)

    return ax.get_xlim(), ax.get_ylim()