def run_synth(self, planes=None, force=False, n_cpu=4):
    """Configure the StarFISH synth stage and run it if products are missing.

    On every call this (re)writes a PHAT AST crowding file for field 0 into
    the synth directory, disables extinction (uniform distribution at 0 for
    young and old populations, relative extinction of 1 in every PHAT band)
    and rebuilds ``self.synth``.

    Parameters
    ----------
    planes : list, optional
        CMD planes to synthesize; defaults to ``self.planes.all_planes``.
    force : bool
        Run the synth code even if ``z*`` products already exist.
    n_cpu : int
        CPU count passed to ``Synth.run_synth`` (default 4, the value that
        used to be hard-coded).
    """
    full_synth_dir = os.path.join(STARFISH, self.synth_dir)
    if not os.path.exists(full_synth_dir):
        os.makedirs(full_synth_dir)

    # Use PHAT AST from the outer field (field 0)
    crowd_path = os.path.join(self.synth_dir, "crowding.dat")
    full_crowd_path = os.path.join(STARFISH, crowd_path)
    tbl = PhatAstTable()
    tbl.write_crowdfile_for_field(full_crowd_path, 0, bands=PHAT_BANDS)
    crowd = ExtantCrowdingTable(crowd_path)

    # No extinction, yet
    young_av = ExtinctionDistribution()
    old_av = ExtinctionDistribution()
    rel_extinction = np.ones(len(PHAT_BANDS), dtype=float)
    for av in (young_av, old_av):
        av.set_uniform(0.)

    if planes is None:
        planes = self.planes.all_planes
    self.synth = Synth(self.synth_dir, self.builder, self.lockfile, crowd,
                       rel_extinction,
                       young_extinction=young_av,
                       old_extinction=old_av,
                       planes=planes,
                       mass_span=(0.08, 150.),
                       nstars=10000000)

    # True when no synthesized "z*" products exist on disk yet.
    synth_missing = len(glob(os.path.join(full_synth_dir, "z*"))) == 0
    if synth_missing or force:
        self.synth.run_synth(n_cpu=n_cpu, clean=False)
def run_synth(self, planes=None, force=False, n_cpu=4):
    """Configure the StarFISH synth stage and run it if products are missing.

    On every call this (re)writes a PHAT AST crowding file for field 0 into
    the synth directory, disables extinction (uniform distribution at 0 for
    young and old populations, relative extinction of 1 in every PHAT band)
    and rebuilds ``self.synth``.

    Parameters
    ----------
    planes : list, optional
        CMD planes to synthesize; defaults to ``self.planes.all_planes``.
    force : bool
        Run the synth code even if ``z*`` products already exist.
    n_cpu : int
        CPU count passed to ``Synth.run_synth`` (default 4, the value that
        used to be hard-coded).
    """
    full_synth_dir = os.path.join(STARFISH, self.synth_dir)
    if not os.path.exists(full_synth_dir):
        os.makedirs(full_synth_dir)

    # Use PHAT AST from the outer field (field 0)
    crowd_path = os.path.join(self.synth_dir, "crowding.dat")
    full_crowd_path = os.path.join(STARFISH, crowd_path)
    tbl = PhatAstTable()
    tbl.write_crowdfile_for_field(full_crowd_path, 0, bands=PHAT_BANDS)
    crowd = ExtantCrowdingTable(crowd_path)

    # No extinction, yet
    young_av = ExtinctionDistribution()
    old_av = ExtinctionDistribution()
    rel_extinction = np.ones(len(PHAT_BANDS), dtype=float)
    for av in (young_av, old_av):
        av.set_uniform(0.)

    if planes is None:
        planes = self.planes.all_planes
    self.synth = Synth(self.synth_dir, self.builder, self.lockfile, crowd,
                       rel_extinction,
                       young_extinction=young_av,
                       old_extinction=old_av,
                       planes=planes,
                       mass_span=(0.08, 150.),
                       nstars=10000000)

    # True when no synthesized "z*" products exist on disk yet.
    synth_missing = len(glob(os.path.join(full_synth_dir, "z*"))) == 0
    if synth_missing or force:
        self.synth.run_synth(n_cpu=n_cpu, clean=False)
def run_synth(self, n_cpu=1, config_only=False):
    """Build ``self.synth`` and, when appropriate, execute the synth code.

    The ``Synth`` configuration is always rebuilt from the pipeline's builder,
    lockfile, crowding table, bands, extinction settings and planes. The
    expensive synthesis itself runs only if ``config_only`` is false and no
    ``z*`` products are present in the synth directory yet.
    """
    synth_root = os.path.join(STARFISH, self.synth_dir)
    self.synth = Synth(
        self.synth_dir,
        self.builder,
        self.lockfile,
        self.crowd,
        self.bands,
        self.rel_extinction,
        young_extinction=self.young_av,
        old_extinction=self.old_av,
        planes=self.all_planes,
        mass_span=(0.08, 150.),
        nstars=10000000
    )
    if config_only:
        # Caller only wants the configuration object, not the products.
        return
    if glob(os.path.join(synth_root, "z*")):
        # Synthesized products already on disk; reuse them.
        return
    self.synth.run_synth(n_cpu=n_cpu, clean=False)
class Pipeline(object):
    """Pipeline for Multi-CMD fitting and comparison.

    Construction runs the StarFISH preparation stages in order:

    1. ``get_isochrones`` -- export isochrones for ``z_grid`` (only when the
       isochrone directory does not exist yet) and install the library.
    2. ``build_lockfile`` -- lock the isochrones into age bins.
    3. ``run_synth`` -- configure and, if needed, run the synth stage.

    Fits are added afterwards with ``fit_planes`` and kept in ``self.fits``.

    Parameters
    ----------
    brick :
        Brick identifier passed to ``Catalog``.
    root_dir : str
        Pipeline directory. The ``isoc``, ``lib`` and ``synth`` paths are
        derived from it; ``isoc`` and ``synth`` are resolved against
        ``STARFISH`` on disk.
    isoc_args : dict, optional
        Extra keyword arguments forwarded to every ``AgeGridRequest``.
    phases : sequence, optional
        If given, only isochrone points whose ``stage`` equals one of these
        values are exported.
    """

    def __init__(self, brick, root_dir, isoc_args=None, phases=None):
        super(Pipeline, self).__init__()
        self.brick = brick
        self.catalog = Catalog(brick)
        self.root_dir = root_dir
        self.isoc_dir = os.path.join(root_dir, 'isoc')
        self.lib_dir = os.path.join(root_dir, 'lib')
        self.synth_dir = os.path.join(root_dir, 'synth')
        self.z_grid = [0.015, 0.019, 0.024]
        self.get_isochrones(isoc_args=isoc_args, phases=phases)
        self.build_lockfile()
        self.planes = PhatPlanes()
        self.run_synth()
        self.fits = OrderedDict()
        self._solution_tables = {}

    def get_solution_table(self, key):
        """Return the SFH solution table of fit ``key``, computing it once.

        The table is cached per key. (Previously a cache hit raised
        ``UnboundLocalError`` because the local ``tbl`` was never bound.)
        """
        if key not in self._solution_tables:
            self._solution_tables[key] = self.fits[key].solution_table()
        return self._solution_tables[key]

    def get_isochrones(self, isoc_args=None, phases=None):
        """Export isochrones for StarFISH and set up ``self.builder``.

        WFC3 and ACS age grids (log age 6.6-10.13, step 0.02) are requested
        for each metallicity in ``z_grid`` and joined band-wise. This only
        happens if the isochrone directory does not exist yet. The library
        is then installed unless its isofile already exists.
        """
        if isoc_args is None:
            isoc_args = {}
        if not os.path.exists(os.path.join(STARFISH, self.isoc_dir)):
            for z in self.z_grid:
                r_wfc3 = AgeGridRequest(z,
                                        min_log_age=6.6,
                                        max_log_age=10.13,
                                        delta_log_age=0.02,
                                        photsys='wfc3_wide',
                                        **isoc_args)
                r_acs = AgeGridRequest(z,
                                       min_log_age=6.6,
                                       max_log_age=10.13,
                                       delta_log_age=0.02,
                                       photsys='acs_wfc',
                                       **isoc_args)
                isoc_set = join_isochrone_sets(r_wfc3.isochrone_set,
                                               r_acs.isochrone_set,
                                               left_bands=WFC3_BANDS,
                                               right_bands=ACS_BANDS)
                for isoc in isoc_set:
                    isoc = Isochrone(isoc)
                    isoc.rename_column('F275W1', 'F275W')
                    if phases is not None:
                        # Keep only the requested evolutionary stages,
                        # grouped in the order the phases were given.
                        sels = []
                        for p in phases:
                            sels.append(np.where(isoc['stage'] == p)[0])
                        s = np.concatenate(sels)
                        isoc = isoc[s]
                    isoc.export_for_starfish(
                        os.path.join(STARFISH, self.isoc_dir),
                        bands=PHAT_BANDS)
        d = Distance(785 * u.kpc)
        self.builder = LibraryBuilder(self.isoc_dir, self.lib_dir,
                                      nmag=len(PHAT_BANDS),
                                      dmod=d.distmod.value,
                                      iverb=3)
        if not os.path.exists(self.builder.full_isofile_path):
            self.builder.install()

    def build_lockfile(self):
        """Create ``self.lockfile`` and lock the isochrones into age bins.

        Young bins: 9 bins equally spaced in log(age) from 6.5 to 8.95.
        Old bins: 12 bins of 1 Gyr from 1 to 13 Gyr, each lower edge pulled
        down by 0.05 Gyr. Every bin spans Z = 0.014-0.025 and is named
        ``z0019_<label>``.
        """
        full_synth_dir = os.path.join(STARFISH, self.synth_dir)
        if not os.path.exists(full_synth_dir):
            os.makedirs(full_synth_dir)
        self.lockfile = Lockfile(self.builder.read_isofile(), self.synth_dir,
                                 unbinned=False)

        # Bin young isochrones (log-age edges)
        young_grid = np.linspace(6.5, 8.95, 10)
        young_bins = list(zip(young_grid[:-1], young_grid[1:]))
        # Bin old isochrones (linear-age edges in years, converted to log)
        old_grid = np.arange(1e9, 14 * 1e9, 1e9)
        old_bins = [(np.log10(age0 - 0.05 * 1e9), np.log10(age1))
                    for age0, age1 in zip(old_grid[:-1], old_grid[1:])]

        z_str = "0019"
        for logage0, logage1 in young_bins + old_bins:
            # NOTE(review): dividing by 0.2 (not 2) makes this label 10x the
            # mean log(age), e.g. ~66.36 for the first young bin. It is
            # baked into on-disk synth product names, so it is kept as-is;
            # confirm whether "/ 2." was intended.
            mean_age = (logage0 + logage1) / 0.2
            name = "z{0}_{1:05.2f}".format(z_str, mean_age)
            self.lockfile.lock_box(name, (logage0, logage1), (0.014, 0.025))

    def run_synth(self, planes=None, force=False, n_cpu=4):
        """Configure the StarFISH synth stage and run it if products are missing.

        On every call this (re)writes a PHAT AST crowding file for field 0
        into the synth directory, disables extinction (uniform distribution
        at 0, relative extinction 1 in every band) and rebuilds
        ``self.synth``.

        Parameters
        ----------
        planes : list, optional
            CMD planes to synthesize; defaults to ``self.planes.all_planes``.
        force : bool
            Run the synth code even if ``z*`` products already exist.
        n_cpu : int
            CPU count for ``Synth.run_synth`` (default 4, formerly hard-coded).
        """
        full_synth_dir = os.path.join(STARFISH, self.synth_dir)
        if not os.path.exists(full_synth_dir):
            os.makedirs(full_synth_dir)

        # Use PHAT AST from the outer field (field 0)
        crowd_path = os.path.join(self.synth_dir, "crowding.dat")
        full_crowd_path = os.path.join(STARFISH, crowd_path)
        tbl = PhatAstTable()
        tbl.write_crowdfile_for_field(full_crowd_path, 0, bands=PHAT_BANDS)
        crowd = ExtantCrowdingTable(crowd_path)

        # No extinction, yet
        young_av = ExtinctionDistribution()
        old_av = ExtinctionDistribution()
        rel_extinction = np.ones(len(PHAT_BANDS), dtype=float)
        for av in (young_av, old_av):
            av.set_uniform(0.)

        if planes is None:
            planes = self.planes.all_planes
        self.synth = Synth(self.synth_dir, self.builder, self.lockfile, crowd,
                           rel_extinction,
                           young_extinction=young_av,
                           old_extinction=old_av,
                           planes=planes,
                           mass_span=(0.08, 150.),
                           nstars=10000000)

        # True when no synthesized "z*" products exist on disk yet.
        synth_missing = len(glob(os.path.join(full_synth_dir, "z*"))) == 0
        if synth_missing or force:
            self.synth.run_synth(n_cpu=n_cpu, clean=False)

    def fit_planes(self, key, color_planes, phot_colors, redo=False):
        """Fit the catalog in ``color_planes`` and store the SFH under ``key``.

        ``phot_colors`` holds one ``(band1, band2)`` pair per plane, used to
        write the observed photometry. The SFH is only run if its output
        file is missing or ``redo`` is set.
        """
        fit_dir = os.path.join(self.root_dir, key)
        data_root = os.path.join(fit_dir, "phot.")
        for plane, (band1, band2) in zip(color_planes, phot_colors):
            self.catalog.write(band1, band2, data_root, plane.suffix)
        sfh = SFH(data_root, self.synth, fit_dir, planes=color_planes)
        if (not os.path.exists(sfh.full_outfile_path)) or redo:
            sfh.run_sfh()
        self.fits[key] = sfh
        # Drop any cached solution table so get_solution_table re-reads it.
        self._solution_tables.pop(key, None)

    def show_isoc_phase_sim_hess(self, fig):
        """Draw simulated Hess diagrams above observed CMD contours.

        Top row: simulated optical (F475W-F814W) and IR (F110W-F160W) Hess
        diagrams. Bottom row: contours of the observed catalog in the same
        planes with isochrone phases overplotted; the phase colorbar is drawn
        in the narrow right-hand column.
        """
        opt_sim = self.planes.get_sim_hess(('f475w', 'f814w'),
                                           self.synth, self.lockfile)
        ir_sim = self.planes.get_sim_hess(('f110w', 'f160w'),
                                          self.synth, self.lockfile)
        opt_cmd = self.planes[('f475w', 'f814w')]
        ir_cmd = self.planes[('f110w', 'f160w')]

        gs = gridspec.GridSpec(2, 3, wspace=0.4, bottom=0.2,
                               width_ratios=[1., 1., 0.1])
        ax_opt = fig.add_subplot(gs[0, 0])
        ax_ir = fig.add_subplot(gs[0, 1])
        ax_obs_opt = fig.add_subplot(gs[1, 0])
        ax_obs_ir = fig.add_subplot(gs[1, 1])
        cb_ax = fig.add_subplot(gs[1, 2])

        plot_hess(ax_opt, opt_sim.hess, opt_cmd, opt_sim.origin,
                  imshow_args=None)
        plot_hess(ax_ir, ir_sim.hess, ir_cmd, ir_sim.origin,
                  imshow_args=None)

        c = self.catalog.data['f475w_vega'] - self.catalog.data['f814w_vega']
        contour_hess(ax_obs_opt, c, self.catalog.data['f814w_vega'],
                     opt_cmd.x_span, opt_cmd.y_span,
                     plot_args={'ms': 3})
        plot_isochrone_phases(ax_obs_opt, 'F475W', 'F814W', show_cb=False)
        ax_obs_opt.set_xlabel(opt_cmd.x_label)
        ax_obs_opt.set_ylabel(opt_cmd.y_label)
        ax_obs_opt.set_xlim(opt_cmd.xlim)
        ax_obs_opt.set_ylim(opt_cmd.ylim)

        c = self.catalog.data['f110w_vega'] - self.catalog.data['f160w_vega']
        contour_hess(ax_obs_ir, c, self.catalog.data['f160w_vega'],
                     ir_cmd.x_span, ir_cmd.y_span,
                     plot_args={'ms': 3})
        plot_isochrone_phases(ax_obs_ir, 'F110W', 'F160W', show_cb=True,
                              cb_ax=cb_ax)
        ax_obs_ir.set_xlabel(ir_cmd.x_label)
        ax_obs_ir.set_ylabel(ir_cmd.y_label)
        ax_obs_ir.set_xlim(ir_cmd.xlim)
        ax_obs_ir.set_ylim(ir_cmd.ylim)

        fig.show()

    def plot_contour_hess(self, ax, bands, plane_key):
        """Contour the observed ``bands[0] - bands[-1]`` vs ``bands[-1]`` CMD."""
        plane = self.planes[plane_key]
        c = self.catalog.data[bands[0]] - self.catalog.data[bands[-1]]
        contour_hess(ax, c, self.catalog.data[bands[-1]],
                     plane.x_span, plane.y_span,
                     plot_args={'ms': 3})
        ax.set_xlabel(plane.x_label)
        ax.set_ylabel(plane.y_label)
        ax.set_xlim(*plane.xlim)
        ax.set_ylim(*plane.ylim)

    def plot_sim_hess(self, ax, plane_key):
        """Plot the simulated Hess diagram for ``plane_key`` into ``ax``."""
        plane = self.planes[plane_key]
        sim = self.planes.get_sim_hess(plane_key, self.synth, self.lockfile)
        plot_hess(ax, sim.hess, plane, sim.origin, imshow_args=None)

    def plot_obs_hess(self, arg1):
        """Placeholder; not implemented."""
        pass

    def plot_fit_hess(self, arg1):
        """Placeholder; not implemented."""
        pass

    def plot_predicted_hess(self, arg1):
        """Placeholder; not implemented."""
        pass

    def plot_triptyk(self, fig, ax_obs, ax_model, ax_chi, fit_key, plane_key,
                     xtick=1., xfmt="%.0f"):
        """Plot observed, model and log chi^2 Hess panels for one fit/plane."""
        def cmapper():
            # A fresh colormap instance per panel, as before.
            return cubehelix.cmap(startHue=240, endHue=-300, minSat=1,
                                  maxSat=2.5, minLight=.3, maxLight=.8,
                                  gamma=.9)

        fit = self.fits[fit_key]
        plane = self.planes[plane_key]
        ctp = ChiTriptykPlot(fit, plane)
        ctp.setup_axes(fig, ax_obs=ax_obs, ax_mod=ax_model, ax_chi=ax_chi,
                       major_x=xtick, major_x_fmt=xfmt)
        ctp.plot_obs_in_ax(ax_obs, cmap=cmapper())
        ctp.plot_mod_in_ax(ax_model, cmap=cmapper())
        ctp.plot_chi_in_ax(ax_chi, cmap=cubehelix.cmap())
        ax_obs.text(0.0, 1.01, "Observed",
                    transform=ax_obs.transAxes, size=8, ha='left')
        ax_model.text(0.0, 1.01, "Model",
                      transform=ax_model.transAxes, size=8, ha='left')
        ax_chi.text(0.0, 1.01, r"$\log \chi^2$",
                    transform=ax_chi.transAxes, size=8, ha='left')

    def plot_isoc_grid_ages(self, ax, band1, band2, show_cb=False,
                            cb_ax=None):
        """Plot a demo isochrone age grid colour-coded by log(age).

        Isochrones are shifted by the distance modulus of 785 kpc.
        """
        isoc_set = get_demo_age_grid(isoc_kind='parsec_CAF09_v1.2S',
                                     photsys_version='yang')
        cmap = cubehelix.cmap(startHue=240, endHue=-300,
                              minSat=1, maxSat=2.5, minLight=.3,
                              maxLight=.8, gamma=.9)
        norm = mpl.colors.Normalize(vmin=7., vmax=10.1)
        scalar_map = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
        # The lines are coloured by log10(age) (and the norm spans log ages),
        # so the mappable's data must be log ages too, not linear ages.
        scalar_map.set_array(np.array([np.log10(isoc.age)
                                       for isoc in isoc_set]))
        d = Distance(785 * u.kpc)
        for isoc in isoc_set:
            ax.plot(isoc[band1] - isoc[band2],
                    isoc[band2] + d.distmod.value,
                    c=scalar_map.to_rgba(np.log10(isoc.age)),
                    lw=0.8)
        if show_cb:
            cb = plt.colorbar(mappable=scalar_map,
                              cax=cb_ax, ax=ax,
                              ticks=np.arange(6., 10.2))
            cb.set_label(r"log(age)")

    def plot_isoc_grid_phases(self, ax, band1, band2, show_cb=False,
                              cb_ax=None):
        """Plot isochrones colour-coded by evolutionary phase."""
        plot_isochrone_phases(ax, band1, band2, show_cb=show_cb, cb_ax=cb_ax)

    def show_lockfile(self, fig, logage_lim=(6.2, 10.2),
                      logzzsol_lim=(-0.2, 0.2)):
        """Show isochrone (log age, log Z/Zsun) points with the lock boxes."""
        ax = fig.add_subplot(111)
        plot_isochrone_logage_logzsol(ax, self.builder, c='k', s=8)
        plot_lock_polygons(ax, self.lockfile, facecolor='None', edgecolor='r')
        ax.set_xlim(*logage_lim)
        ax.set_ylim(*logzzsol_lim)
        ax.set_xlabel(r"$\log(A)$")
        ax.set_ylabel(r"$\log(Z/Z_\odot)$")
        fig.show()
class Pipeline(object):
    """Pipeline for Multi-CMD fitting and comparison.

    Construction runs the StarFISH preparation stages in order:

    1. ``get_isochrones`` -- export isochrones for ``z_grid`` (only when the
       isochrone directory does not exist yet) and install the library.
    2. ``build_lockfile`` -- lock the isochrones into age bins.
    3. ``run_synth`` -- configure and, if needed, run the synth stage.

    Fits are added afterwards with ``fit_planes`` and kept in ``self.fits``.

    Parameters
    ----------
    brick :
        Brick identifier passed to ``Catalog``.
    root_dir : str
        Pipeline directory. The ``isoc``, ``lib`` and ``synth`` paths are
        derived from it; ``isoc`` and ``synth`` are resolved against
        ``STARFISH`` on disk.
    isoc_args : dict, optional
        Extra keyword arguments forwarded to every ``AgeGridRequest``.
    phases : sequence, optional
        If given, only isochrone points whose ``stage`` equals one of these
        values are exported.
    """

    def __init__(self, brick, root_dir, isoc_args=None, phases=None):
        super(Pipeline, self).__init__()
        self.brick = brick
        self.catalog = Catalog(brick)
        self.root_dir = root_dir
        self.isoc_dir = os.path.join(root_dir, 'isoc')
        self.lib_dir = os.path.join(root_dir, 'lib')
        self.synth_dir = os.path.join(root_dir, 'synth')
        self.z_grid = [0.015, 0.019, 0.024]
        self.get_isochrones(isoc_args=isoc_args, phases=phases)
        self.build_lockfile()
        self.planes = PhatPlanes()
        self.run_synth()
        self.fits = OrderedDict()
        self._solution_tables = {}

    def get_solution_table(self, key):
        """Return the SFH solution table of fit ``key``, computing it once.

        The table is cached per key. (Previously a cache hit raised
        ``UnboundLocalError`` because the local ``tbl`` was never bound.)
        """
        if key not in self._solution_tables:
            self._solution_tables[key] = self.fits[key].solution_table()
        return self._solution_tables[key]

    def get_isochrones(self, isoc_args=None, phases=None):
        """Export isochrones for StarFISH and set up ``self.builder``.

        WFC3 and ACS age grids (log age 6.6-10.13, step 0.02) are requested
        for each metallicity in ``z_grid`` and joined band-wise. This only
        happens if the isochrone directory does not exist yet. The library
        is then installed unless its isofile already exists.
        """
        if isoc_args is None:
            isoc_args = {}
        if not os.path.exists(os.path.join(STARFISH, self.isoc_dir)):
            for z in self.z_grid:
                r_wfc3 = AgeGridRequest(z,
                                        min_log_age=6.6,
                                        max_log_age=10.13,
                                        delta_log_age=0.02,
                                        photsys='wfc3_wide',
                                        **isoc_args)
                r_acs = AgeGridRequest(z,
                                       min_log_age=6.6,
                                       max_log_age=10.13,
                                       delta_log_age=0.02,
                                       photsys='acs_wfc',
                                       **isoc_args)
                isoc_set = join_isochrone_sets(r_wfc3.isochrone_set,
                                               r_acs.isochrone_set,
                                               left_bands=WFC3_BANDS,
                                               right_bands=ACS_BANDS)
                for isoc in isoc_set:
                    isoc = Isochrone(isoc)
                    isoc.rename_column('F275W1', 'F275W')
                    if phases is not None:
                        # Keep only the requested evolutionary stages,
                        # grouped in the order the phases were given.
                        sels = []
                        for p in phases:
                            sels.append(np.where(isoc['stage'] == p)[0])
                        s = np.concatenate(sels)
                        isoc = isoc[s]
                    isoc.export_for_starfish(
                        os.path.join(STARFISH, self.isoc_dir),
                        bands=PHAT_BANDS)
        d = Distance(785 * u.kpc)
        self.builder = LibraryBuilder(self.isoc_dir, self.lib_dir,
                                      nmag=len(PHAT_BANDS),
                                      dmod=d.distmod.value,
                                      iverb=3)
        if not os.path.exists(self.builder.full_isofile_path):
            self.builder.install()

    def build_lockfile(self):
        """Create ``self.lockfile`` and lock the isochrones into age bins.

        Young bins: 9 bins equally spaced in log(age) from 6.5 to 8.95.
        Old bins: 12 bins of 1 Gyr from 1 to 13 Gyr, each lower edge pulled
        down by 0.05 Gyr. Every bin spans Z = 0.014-0.025 and is named
        ``z0019_<label>``.
        """
        full_synth_dir = os.path.join(STARFISH, self.synth_dir)
        if not os.path.exists(full_synth_dir):
            os.makedirs(full_synth_dir)
        self.lockfile = Lockfile(self.builder.read_isofile(), self.synth_dir,
                                 unbinned=False)

        # Bin young isochrones (log-age edges)
        young_grid = np.linspace(6.5, 8.95, 10)
        young_bins = list(zip(young_grid[:-1], young_grid[1:]))
        # Bin old isochrones (linear-age edges in years, converted to log)
        old_grid = np.arange(1e9, 14 * 1e9, 1e9)
        old_bins = [(np.log10(age0 - 0.05 * 1e9), np.log10(age1))
                    for age0, age1 in zip(old_grid[:-1], old_grid[1:])]

        z_str = "0019"
        for logage0, logage1 in young_bins + old_bins:
            # NOTE(review): dividing by 0.2 (not 2) makes this label 10x the
            # mean log(age), e.g. ~66.36 for the first young bin. It is
            # baked into on-disk synth product names, so it is kept as-is;
            # confirm whether "/ 2." was intended.
            mean_age = (logage0 + logage1) / 0.2
            name = "z{0}_{1:05.2f}".format(z_str, mean_age)
            self.lockfile.lock_box(name, (logage0, logage1), (0.014, 0.025))

    def run_synth(self, planes=None, force=False, n_cpu=4):
        """Configure the StarFISH synth stage and run it if products are missing.

        On every call this (re)writes a PHAT AST crowding file for field 0
        into the synth directory, disables extinction (uniform distribution
        at 0, relative extinction 1 in every band) and rebuilds
        ``self.synth``.

        Parameters
        ----------
        planes : list, optional
            CMD planes to synthesize; defaults to ``self.planes.all_planes``.
        force : bool
            Run the synth code even if ``z*`` products already exist.
        n_cpu : int
            CPU count for ``Synth.run_synth`` (default 4, formerly hard-coded).
        """
        full_synth_dir = os.path.join(STARFISH, self.synth_dir)
        if not os.path.exists(full_synth_dir):
            os.makedirs(full_synth_dir)

        # Use PHAT AST from the outer field (field 0)
        crowd_path = os.path.join(self.synth_dir, "crowding.dat")
        full_crowd_path = os.path.join(STARFISH, crowd_path)
        tbl = PhatAstTable()
        tbl.write_crowdfile_for_field(full_crowd_path, 0, bands=PHAT_BANDS)
        crowd = ExtantCrowdingTable(crowd_path)

        # No extinction, yet
        young_av = ExtinctionDistribution()
        old_av = ExtinctionDistribution()
        rel_extinction = np.ones(len(PHAT_BANDS), dtype=float)
        for av in (young_av, old_av):
            av.set_uniform(0.)

        if planes is None:
            planes = self.planes.all_planes
        self.synth = Synth(self.synth_dir, self.builder, self.lockfile, crowd,
                           rel_extinction,
                           young_extinction=young_av,
                           old_extinction=old_av,
                           planes=planes,
                           mass_span=(0.08, 150.),
                           nstars=10000000)

        # True when no synthesized "z*" products exist on disk yet.
        synth_missing = len(glob(os.path.join(full_synth_dir, "z*"))) == 0
        if synth_missing or force:
            self.synth.run_synth(n_cpu=n_cpu, clean=False)

    def fit_planes(self, key, color_planes, phot_colors, redo=False):
        """Fit the catalog in ``color_planes`` and store the SFH under ``key``.

        ``phot_colors`` holds one ``(band1, band2)`` pair per plane, used to
        write the observed photometry. The SFH is only run if its output
        file is missing or ``redo`` is set.
        """
        fit_dir = os.path.join(self.root_dir, key)
        data_root = os.path.join(fit_dir, "phot.")
        for plane, (band1, band2) in zip(color_planes, phot_colors):
            self.catalog.write(band1, band2, data_root, plane.suffix)
        sfh = SFH(data_root, self.synth, fit_dir, planes=color_planes)
        if (not os.path.exists(sfh.full_outfile_path)) or redo:
            sfh.run_sfh()
        self.fits[key] = sfh
        # Drop any cached solution table so get_solution_table re-reads it.
        self._solution_tables.pop(key, None)

    def show_isoc_phase_sim_hess(self, fig):
        """Draw simulated Hess diagrams above observed CMD contours.

        Top row: simulated optical (F475W-F814W) and IR (F110W-F160W) Hess
        diagrams. Bottom row: contours of the observed catalog in the same
        planes with isochrone phases overplotted; the phase colorbar is drawn
        in the narrow right-hand column.
        """
        opt_sim = self.planes.get_sim_hess(('f475w', 'f814w'),
                                           self.synth, self.lockfile)
        ir_sim = self.planes.get_sim_hess(('f110w', 'f160w'),
                                          self.synth, self.lockfile)
        opt_cmd = self.planes[('f475w', 'f814w')]
        ir_cmd = self.planes[('f110w', 'f160w')]

        gs = gridspec.GridSpec(2, 3, wspace=0.4, bottom=0.2,
                               width_ratios=[1., 1., 0.1])
        ax_opt = fig.add_subplot(gs[0, 0])
        ax_ir = fig.add_subplot(gs[0, 1])
        ax_obs_opt = fig.add_subplot(gs[1, 0])
        ax_obs_ir = fig.add_subplot(gs[1, 1])
        cb_ax = fig.add_subplot(gs[1, 2])

        plot_hess(ax_opt, opt_sim.hess, opt_cmd, opt_sim.origin,
                  imshow_args=None)
        plot_hess(ax_ir, ir_sim.hess, ir_cmd, ir_sim.origin,
                  imshow_args=None)

        c = self.catalog.data['f475w_vega'] - self.catalog.data['f814w_vega']
        contour_hess(ax_obs_opt, c, self.catalog.data['f814w_vega'],
                     opt_cmd.x_span, opt_cmd.y_span,
                     plot_args={'ms': 3})
        plot_isochrone_phases(ax_obs_opt, 'F475W', 'F814W', show_cb=False)
        ax_obs_opt.set_xlabel(opt_cmd.x_label)
        ax_obs_opt.set_ylabel(opt_cmd.y_label)
        ax_obs_opt.set_xlim(opt_cmd.xlim)
        ax_obs_opt.set_ylim(opt_cmd.ylim)

        c = self.catalog.data['f110w_vega'] - self.catalog.data['f160w_vega']
        contour_hess(ax_obs_ir, c, self.catalog.data['f160w_vega'],
                     ir_cmd.x_span, ir_cmd.y_span,
                     plot_args={'ms': 3})
        plot_isochrone_phases(ax_obs_ir, 'F110W', 'F160W', show_cb=True,
                              cb_ax=cb_ax)
        ax_obs_ir.set_xlabel(ir_cmd.x_label)
        ax_obs_ir.set_ylabel(ir_cmd.y_label)
        ax_obs_ir.set_xlim(ir_cmd.xlim)
        ax_obs_ir.set_ylim(ir_cmd.ylim)

        fig.show()

    def plot_contour_hess(self, ax, bands, plane_key):
        """Contour the observed ``bands[0] - bands[-1]`` vs ``bands[-1]`` CMD."""
        plane = self.planes[plane_key]
        c = self.catalog.data[bands[0]] - self.catalog.data[bands[-1]]
        contour_hess(ax, c, self.catalog.data[bands[-1]],
                     plane.x_span, plane.y_span,
                     plot_args={'ms': 3})
        ax.set_xlabel(plane.x_label)
        ax.set_ylabel(plane.y_label)
        ax.set_xlim(*plane.xlim)
        ax.set_ylim(*plane.ylim)

    def plot_sim_hess(self, ax, plane_key):
        """Plot the simulated Hess diagram for ``plane_key`` into ``ax``."""
        plane = self.planes[plane_key]
        sim = self.planes.get_sim_hess(plane_key, self.synth, self.lockfile)
        plot_hess(ax, sim.hess, plane, sim.origin, imshow_args=None)

    def plot_obs_hess(self, arg1):
        """Placeholder; not implemented."""
        pass

    def plot_fit_hess(self, arg1):
        """Placeholder; not implemented."""
        pass

    def plot_predicted_hess(self, arg1):
        """Placeholder; not implemented."""
        pass

    def plot_triptyk(self, fig, ax_obs, ax_model, ax_chi, fit_key, plane_key,
                     xtick=1., xfmt="%.0f"):
        """Plot observed, model and log chi^2 Hess panels for one fit/plane."""
        def cmapper():
            # A fresh colormap instance per panel, as before.
            return cubehelix.cmap(startHue=240, endHue=-300, minSat=1,
                                  maxSat=2.5, minLight=.3, maxLight=.8,
                                  gamma=.9)

        fit = self.fits[fit_key]
        plane = self.planes[plane_key]
        ctp = ChiTriptykPlot(fit, plane)
        ctp.setup_axes(fig, ax_obs=ax_obs, ax_mod=ax_model, ax_chi=ax_chi,
                       major_x=xtick, major_x_fmt=xfmt)
        ctp.plot_obs_in_ax(ax_obs, cmap=cmapper())
        ctp.plot_mod_in_ax(ax_model, cmap=cmapper())
        ctp.plot_chi_in_ax(ax_chi, cmap=cubehelix.cmap())
        ax_obs.text(0.0, 1.01, "Observed",
                    transform=ax_obs.transAxes, size=8, ha='left')
        ax_model.text(0.0, 1.01, "Model",
                      transform=ax_model.transAxes, size=8, ha='left')
        ax_chi.text(0.0, 1.01, r"$\log \chi^2$",
                    transform=ax_chi.transAxes, size=8, ha='left')

    def plot_isoc_grid_ages(self, ax, band1, band2, show_cb=False,
                            cb_ax=None):
        """Plot a demo isochrone age grid colour-coded by log(age).

        Isochrones are shifted by the distance modulus of 785 kpc.
        """
        isoc_set = get_demo_age_grid(isoc_kind='parsec_CAF09_v1.2S',
                                     photsys_version='yang')
        cmap = cubehelix.cmap(startHue=240, endHue=-300,
                              minSat=1, maxSat=2.5, minLight=.3,
                              maxLight=.8, gamma=.9)
        norm = mpl.colors.Normalize(vmin=7., vmax=10.1)
        scalar_map = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
        # The lines are coloured by log10(age) (and the norm spans log ages),
        # so the mappable's data must be log ages too, not linear ages.
        scalar_map.set_array(np.array([np.log10(isoc.age)
                                       for isoc in isoc_set]))
        d = Distance(785 * u.kpc)
        for isoc in isoc_set:
            ax.plot(isoc[band1] - isoc[band2],
                    isoc[band2] + d.distmod.value,
                    c=scalar_map.to_rgba(np.log10(isoc.age)),
                    lw=0.8)
        if show_cb:
            cb = plt.colorbar(mappable=scalar_map,
                              cax=cb_ax, ax=ax,
                              ticks=np.arange(6., 10.2))
            cb.set_label(r"log(age)")

    def plot_isoc_grid_phases(self, ax, band1, band2, show_cb=False,
                              cb_ax=None):
        """Plot isochrones colour-coded by evolutionary phase."""
        plot_isochrone_phases(ax, band1, band2, show_cb=show_cb, cb_ax=cb_ax)

    def show_lockfile(self, fig, logage_lim=(6.2, 10.2),
                      logzzsol_lim=(-0.2, 0.2)):
        """Show isochrone (log age, log Z/Zsun) points with the lock boxes."""
        ax = fig.add_subplot(111)
        plot_isochrone_logage_logzsol(ax, self.builder, c='k', s=8)
        plot_lock_polygons(ax, self.lockfile, facecolor='None', edgecolor='r')
        ax.set_xlim(*logage_lim)
        ax.set_ylim(*logzzsol_lim)
        ax.set_xlabel(r"$\log(A)$")
        ax.set_ylabel(r"$\log(Z/Z_\odot)$")
        fig.show()
class PipelineBase(object):
    """Abstract baseclass for running StarFISH pipelines.

    ``__init__`` consumes the keyword arguments ``root_dir`` (required),
    ``n_synth_cpu`` (default 1) and ``synth_config_only`` (default False).
    It creates the ``isoc``, ``lib`` and ``synth`` product directories under
    ``STARFISH/root_dir`` and then runs, in order: ``setup_isochrones``,
    ``build_lockfile``, ``build_crowding``, ``build_extinction``,
    ``mask_planes`` and ``run_synth``.

    Only ``mask_planes`` is declared (abstract) here. The other stage
    methods, ``get_sim_hess``, and the attributes this class reads but never
    sets (``builder``, ``lockfile``, ``crowd``, ``bands``, ``rel_extinction``,
    ``young_av``, ``old_av``, ``all_planes``, ``planes``) must come from
    subclasses or mixins.
    """

    __metaclass__ = abc.ABCMeta

    def __init__(self, **kwargs):
        self.root_dir = kwargs.pop('root_dir')
        self.n_synth_cpu = kwargs.pop('n_synth_cpu', 1)
        self.synth_config_only = kwargs.pop('synth_config_only', False)

        # StarFISH product directories
        self.isoc_dir = os.path.join(self.root_dir, 'isoc')
        self.lib_dir = os.path.join(self.root_dir, 'lib')
        self.synth_dir = os.path.join(self.root_dir, 'synth')

        # result caches
        self.fits = OrderedDict()
        self._solution_tables = {}

        if len(kwargs) > 0:
            # Single-argument print: same output on Python 2 and 3.
            print("Uncaught arguments: {0}".format(kwargs))
        super(PipelineBase, self).__init__()

        for d in (self.isoc_dir, self.lib_dir, self.synth_dir):
            full_dir = os.path.join(STARFISH, d)
            if not os.path.exists(full_dir):
                os.makedirs(full_dir)

        # Now run the pipeline
        self.setup_isochrones()
        self.build_lockfile()
        self.build_crowding()
        self.build_extinction()
        self.mask_planes()  # mask planes based on completeness cuts
        # A config-only run never executes the synth code, so one CPU suffices.
        n_cpu = 1 if self.synth_config_only else self.n_synth_cpu
        self.run_synth(n_cpu=n_cpu, config_only=self.synth_config_only)

    def run_synth(self, n_cpu=1, config_only=False):
        """Build ``self.synth``; run it unless config-only or already done.

        The synth code runs only if ``config_only`` is false and no ``z*``
        products exist in the synth directory.
        """
        full_synth_dir = os.path.join(STARFISH, self.synth_dir)
        self.synth = Synth(self.synth_dir, self.builder, self.lockfile,
                           self.crowd,
                           self.bands,
                           self.rel_extinction,
                           young_extinction=self.young_av,
                           old_extinction=self.old_av,
                           planes=self.all_planes,
                           mass_span=(0.08, 150.),
                           nstars=10000000)
        existing_synth = len(glob(os.path.join(full_synth_dir, "z*"))) > 0
        if not config_only and not existing_synth:
            self.synth.run_synth(n_cpu=n_cpu, clean=False)

    @abc.abstractmethod
    def mask_planes(self):
        """Mask CMD planes (called during ``__init__``); subclass hook."""
        pass

    @property
    def hold_template(self):
        """The lockfile's empty hold, a template for ``fit(..., hold=...)``."""
        return self.lockfile.empty_hold

    def fit(self, fit_key, plane_keys, dataset, fit_dir=None, redo=False,
            hold=None):
        """Fit ``dataset`` in the planes ``plane_keys``; store as ``fit_key``.

        Photometry for each plane is written under ``fit_dir`` (default
        ``root_dir/fit_key``). The SFH is only run if its output file is
        missing or ``redo`` is set; ``hold`` is passed to ``run_sfh``.
        """
        if fit_dir is None:
            fit_dir = os.path.join(self.root_dir, fit_key)
        data_root = os.path.join(fit_dir, "phot.")
        planes = []
        for plane_key in plane_keys:
            plane = self.planes[plane_key]
            planes.append(plane)
            dataset.write_phot(plane.x_mag, plane.y_mag, data_root,
                               plane.suffix)
        sfh = SFH(data_root, self.synth, fit_dir, planes=planes)
        if (not os.path.exists(sfh.full_outfile_path)) or redo:
            sfh.run_sfh(hold=hold)
        self.fits[fit_key] = sfh
        # Invalidate any cached solution table for this key (the cache is
        # presumably filled by subclasses).
        self._solution_tables.pop(fit_key, None)

    def make_fit_diff_hess(self, dataset, fit_key, plane_key):
        """Return the observed-minus-fit Hess diagram for a plane."""
        obs_hess = self.make_obs_hess(dataset, plane_key)
        fit_hess = self.make_fit_hess(fit_key, plane_key)
        return Hess(obs_hess.hess - fit_hess.hess, self.planes[plane_key])

    def make_chisq_hess(self, dataset, fit_key, plane_key):
        """Return a Hess of per-pixel ((obs - fit) / sqrt(obs))**2."""
        obs_hess = self.make_obs_hess(dataset, plane_key)
        fit_hess = self.make_fit_hess(fit_key, plane_key)
        sigma = np.sqrt(obs_hess.hess)
        # Empty observed pixels give sigma == 0 and hence inf/nan values;
        # compute_fit_chi skips non-finite pixels, so silence the warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            chi = ((obs_hess.hess - fit_hess.hess) / sigma) ** 2.
        return Hess(chi, self.planes[plane_key])

    def compute_fit_chi(self, dataset, fit_key, plane_key, chi_hess=None):
        """Compute the reduced chi-sq for the plane with the given fit.

        Only finite pixels of ``chi_hess.masked_hess`` are summed. The degrees
        of freedom are the number of such pixels minus the number of active
        lockfile groups (the fitted amplitudes).

        Parameters
        ----------
        chi_hess : optional
            Precomputed chi-sq Hess; built with ``make_chisq_hess`` if None.

        Returns
        -------
        float
            ``chi_sum / (n_pix - n_amp)``, or NaN when the degrees of
            freedom are not positive.
        """
        if chi_hess is None:
            chi_hess = self.make_chisq_hess(dataset, fit_key, plane_key)
        g = np.where(np.isfinite(chi_hess.masked_hess))
        n_pix = len(g[0])
        chi_sum = chi_hess.masked_hess[g].sum()
        n_amp = len(self.lockfile.active_groups)
        dof = n_pix - n_amp
        if dof <= 0:
            # Reduced chi-sq is undefined without positive degrees of freedom.
            return np.nan
        return chi_sum / dof

    def make_sim_hess(self, plane_key):
        """Return the simulated Hess for ``plane_key`` (via ``get_sim_hess``)."""
        return self.get_sim_hess(plane_key)

    def make_fit_hess(self, fit_key, plane_key):
        """Return the model Hess of fit ``fit_key`` in ``plane_key``."""
        plane = self.planes[plane_key]
        return SimHess.from_sfh_solution(self.fits[fit_key], plane)

    def make_obs_hess(self, dataset, plane_key):
        """Return the observed Hess of ``dataset`` in ``plane_key``."""
        plane = self.planes[plane_key]
        x = dataset.get_phot(plane.x_mag)
        y = dataset.get_phot(plane.y_mag)
        return StarCatalogHess(x, y, plane)

    def init_plane_axes(self, ax, plane_key):
        """Set up ``ax`` for Hess plots of ``plane_key`` (lower origin)."""
        plane = self.planes[plane_key]
        setup_hess_axes(ax, plane, 'lower')

    def plot_sim_hess(self, ax, plane_key, imshow=None):
        """Plot the simulated Hess for ``plane_key``."""
        plane = self.planes[plane_key]
        sim = self.get_sim_hess(plane_key)
        return plot_hess(ax, sim.hess, plane, sim.origin, imshow_args=imshow)

    def plot_fit_hess(self, ax, fit_key, plane_key, imshow=None):
        """Plot the model Hess of fit ``fit_key`` in ``plane_key``."""
        plane = self.planes[plane_key]
        fit_hess = self.make_fit_hess(fit_key, plane_key)
        return plot_hess(ax, fit_hess.hess, plane, fit_hess.origin,
                         imshow_args=imshow)

    def plot_obs_hess(self, ax, dataset, plane_key, imshow=None):
        """Plot the observed Hess of ``dataset`` in ``plane_key``."""
        plane = self.planes[plane_key]
        obs_hess = self.make_obs_hess(dataset, plane_key)
        return plot_hess(ax, obs_hess.hess, plane, obs_hess.origin,
                         imshow_args=imshow)

    def plot_hess_array(self, ax, hess, plane_key, imshow=None, log=True):
        """Plot an arbitrary Hess array laid out like ``plane_key``."""
        plane = self.planes[plane_key]
        return plot_hess(ax, hess, plane, plane.origin,
                         imshow_args=imshow, log=log)

    def plot_lockfile(self, ax, logage_lim=(6.2, 10.2),
                      logzzsol_lim=(-0.2, 0.2)):
        """Plot isochrone (log age, log Z/Zsun) points and the lock boxes."""
        plot_isochrone_logage_logzsol(ax, self.builder, c='k', s=8)
        plot_lock_polygons(ax, self.lockfile, facecolor='None', edgecolor='r')
        ax.set_xlim(*logage_lim)
        ax.set_ylim(*logzzsol_lim)
        ax.set_xlabel(r"$\log(A)$")
        ax.set_ylabel(r"$\log(Z/Z_\odot)$")

    def plot_triptyk(self, fig, ax_obs, ax_model, ax_chi, fit_key, plane_key,
                     xtick=1., xfmt="%.0f"):
        """Plot observed, model and log chi^2 Hess panels for one fit/plane."""
        fit = self.fits[fit_key]
        plane = self.planes[plane_key]
        ctp = ChiTriptykPlot(fit, plane)
        ctp.setup_axes(fig, ax_obs=ax_obs, ax_mod=ax_model, ax_chi=ax_chi,
                       major_x=xtick, major_x_fmt=xfmt)
        ctp.plot_obs_in_ax(ax_obs, cmap=perceptual_rainbow_16.mpl_colormap)
        ctp.plot_mod_in_ax(ax_model, cmap=perceptual_rainbow_16.mpl_colormap)
        ctp.plot_chi_in_ax(ax_chi, cmap=perceptual_rainbow_16.mpl_colormap)
        ax_obs.text(0.0, 1.01, "Observed",
                    transform=ax_obs.transAxes, size=8, ha='left')
        ax_model.text(0.0, 1.01, "Model",
                      transform=ax_model.transAxes, size=8, ha='left')
        ax_chi.text(0.0, 1.01, r"$\log \chi^2$",
                    transform=ax_chi.transAxes, size=8, ha='left')

    def plot_linear_sfh_circles(self, ax, fit_key, ylim=(-0.2, 0.2),
                                amp_key='sfr'):
        """Plot the SFH of ``fit_key`` as circles on a linear age axis."""
        sfh = self.fits[fit_key]
        cp = LinearSFHCirclePlot(sfh.solution_table())
        cp.plot_in_ax(ax, max_area=800, amp_key=amp_key)
        for tl in ax.get_ymajorticklabels():
            tl.set_visible(False)
        ax.set_ylim(*ylim)

    def plot_log_sfh_circles(self, ax, fit_key, ylim=(-0.2, 0.2),
                             amp_key='sfr'):
        """Plot the SFH of ``fit_key`` as circles, with 1-12 Gyr guide lines."""
        sfh = self.fits[fit_key]
        cp = SFHCirclePlot(sfh.solution_table())
        cp.plot_in_ax(ax, max_area=800, amp_key=amp_key)
        for logage in np.log10(np.arange(1, 13, 1) * 1e9):
            ax.axvline(logage, c='0.8', zorder=-1)
        ax.set_ylim(*ylim)